Run black autoformatter

This commit is contained in:
2024-05-30 13:18:51 +02:00
parent 2f3c34b8b5
commit fcd7d96126
7 changed files with 110 additions and 85 deletions

View File

@@ -2,12 +2,13 @@ import torch
from utils import print_cpp_vector
def gen_batch_norm_test_result(input):
batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=False)
weights = torch.Tensor([0.63508, 0.64903])
biases= torch.Tensor([0.25079, 0.66841])
biases = torch.Tensor([0.25079, 0.66841])
batch_norm.weight = torch.nn.Parameter(weights)
batch_norm.bias = torch.nn.Parameter(biases)
@@ -15,11 +16,13 @@ def gen_batch_norm_test_result(input):
output = batch_norm(input)
print_cpp_vector(output.flatten())
if __name__ == "__main__":
print("Generating test results...")
print("Batch norm test:")
# fmt: off
input = torch.Tensor([
# Channel 0
0.38899, 0.80478, 0.48836, 0.97381,
@@ -32,11 +35,13 @@ if __name__ == "__main__":
0.13449, 0.27367, 0.53036, 0.18962,
0.57672, 0.48364, 0.10863, 0.0571
]).reshape(1, 2, 4, 4)
# fmt: on
gen_batch_norm_test_result(input)
print("Batch norm test non square input:")
# fmt: off
input = torch.Tensor([
0.38899, 0.80478, 0.48836, 0.97381, 0.21567, 0.92312,
0.57508, 0.60835, 0.65467, 0.00168, 0.31567, 0.71345,
@@ -47,6 +52,6 @@ if __name__ == "__main__":
0.13449, 0.27367, 0.53036, 0.18962, 0.45623, 0.14523,
0.57672, 0.48364, 0.10863, 0.0571, 0.78934, 0.67545
]).reshape(1, 2, 4, 6)
# fmt: on
gen_batch_norm_test_result(input)