Run black autoformatter

This commit is contained in:
2024-05-30 13:18:51 +02:00
parent 2f3c34b8b5
commit fcd7d96126
7 changed files with 110 additions and 85 deletions

View File

@@ -2,6 +2,7 @@ import torch
from utils import export_model_weights, print_cpp_vector
class TestModel(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
@@ -13,24 +14,18 @@ class TestModel(torch.nn.Module):
kernel_size=3,
stride=1,
padding=0,
bias=False
)
self.maxpool1 = torch.nn.MaxPool2d(
kernel_size=2,
stride=2
bias=False,
)
self.maxpool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
self.activation = torch.nn.ReLU()
self.linear = torch.nn.Linear(
in_features=8,
out_features=3,
bias=False
)
self.linear = torch.nn.Linear(in_features=8, out_features=3, bias=False)
self.softmax = torch.nn.Softmax(dim=0)
def set_weights(self):
# fmt: off
conv2d_weights = torch.tensor([
0.18313, 0.53363, 0.39527, 0.27575, 0.3433, 0.41746,
0.16831, 0.61693, 0.54599, 0.99692, 0.77127, 0.25146,
@@ -40,8 +35,10 @@ class TestModel(torch.nn.Module):
0.68407, 0.2684, 0.2855, 0.76195, 0.67828, 0.603
]).reshape(2, 2, 3, 3)
# fmt: on
self.conv1.weight = torch.nn.Parameter(conv2d_weights)
# fmt: off
linear_weights = torch.tensor([
0.36032, 0.33115, 0.02948,
0.09802, 0.45072, 0.56266,
@@ -52,6 +49,7 @@ class TestModel(torch.nn.Module):
0.51559, 0.81916, 0.64915,
0.03934, 0.87608, 0.68364,
]).reshape(3, 8)
# fmt: on
self.linear.weight = torch.nn.Parameter(linear_weights)
def forward(self, x):
@@ -64,11 +62,13 @@ class TestModel(torch.nn.Module):
x = self.softmax(x)
return x
if __name__ == "__main__":
model = TestModel()
model.set_weights()
# fmt: off
input = torch.tensor([
0.12762, 0.99056, 0.77565, 0.29058, 0.29787, 0.58415, 0.20484,
0.05415, 0.60593, 0.3162, 0.08198, 0.92749, 0.72392, 0.91786,
@@ -82,14 +82,14 @@ if __name__ == "__main__":
0.84854, 0.61415, 0.2466, 0.20017, 0.78952, 0.93797, 0.27884,
0.30514, 0.23521
]).reshape(2, 6, 6)
# input = torch.rand(2, 6, 6)
# fmt: on
print("Single test output:")
out = model(input)
print_cpp_vector(out)
print("Multiple predict test output 1:")
# fmt: off
input = torch.tensor([
0.81247, 0.03579, 0.26577, 0.80374, 0.64584, 0.19658, 0.04817,
0.50769, 0.33502, 0.01739, 0.32263, 0.69625, 0.07433, 0.98283,
@@ -103,10 +103,12 @@ if __name__ == "__main__":
0.16811, 0.72188, 0.08683, 0.66985, 0.62707, 0.4035, 0.51822,
0.46545, 0.88722
]).reshape(2, 6, 6)
# fmt: on
out = model(input)
print_cpp_vector(out)
print("Multiple predict test output 2:")
# fmt: off
input = torch.tensor([
0.83573, 0.19191, 0.16004, 0.27137, 0.64768, 0.38417, 0.02167,
0.28834, 0.21401, 0.16624, 0.12037, 0.12706, 0.3588, 0.10685,
@@ -120,6 +122,7 @@ if __name__ == "__main__":
0.66075, 0.64496, 0.1191, 0.66261, 0.63431, 0.7137, 0.14851,
0.84456, 0.44482
]).reshape(2, 6, 6)
# fmt: on
out = model(input)
print_cpp_vector(out)