Run black autoformatter

This commit is contained in:
2024-05-30 13:18:51 +02:00
parent 2f3c34b8b5
commit fcd7d96126
7 changed files with 110 additions and 85 deletions

View File

@@ -2,20 +2,17 @@ import torch
from utils import print_cpp_vector
def _conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
inputs,
weights):
conv2d = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False)
def _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
conv2d = torch.nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False,
)
conv2d.weight = torch.nn.Parameter(weights)
output = conv2d(inputs)
@@ -24,6 +21,7 @@ def _conv2d(in_channels,
output = torch.flatten(output)
return output
def gen_convd_padded_test_result():
in_channels = 3
@@ -32,7 +30,7 @@ def gen_convd_padded_test_result():
stride = 1
padding = 1
# Define input and kernel data as tensors
# fmt: off
inputs = torch.tensor([
0.823, 0.217, 0.435, 0.981, 0.742,
0.109, 0.518, 0.374, 0.681, 0.147,
@@ -71,15 +69,12 @@ def gen_convd_padded_test_result():
0.678, 0.011, 0.345,
0.011, 0.345, 0.678
], dtype=torch.float).reshape(2, 3, 3, 3)
# fmt: on
output = _conv2d(
in_channels, out_channels, kernel_size, stride, padding, inputs, weights
)
output = _conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
inputs,
weights)
print_cpp_vector(output)
@@ -91,6 +86,7 @@ def gen_convd_strided_test_result():
stride = 2
padding = 3
# fmt: off
input = torch.tensor([
0.946, 0.879, 0.382, 0.542, 0.453,
0.128, 0.860, 0.778, 0.049, 0.974,
@@ -103,6 +99,7 @@ def gen_convd_strided_test_result():
0.473, 0.303, 0.084, 0.785, 0.444,
0.464, 0.413, 0.779, 0.298, 0.783
], dtype=torch.float).reshape(1, 2, 5, 5)
weights = torch.tensor([
0.744, 0.745, 0.641,
0.164, 0.157, 0.127,
@@ -117,15 +114,12 @@ def gen_convd_strided_test_result():
0.236, 0.397, 0.739,
0.939, 0.891, 0.006
], dtype=torch.float).reshape(2, 2, 3, 3)
# fmt: on
output = _conv2d(
in_channels, out_channels, kernel_size, stride, padding, input, weights
)
output = _conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
input,
weights)
print_cpp_vector(output)
@@ -137,6 +131,7 @@ def gen_convd_non_square_input_test_result():
stride = 1
padding = 0
# fmt: off
input = torch.tensor([
0.946, 0.879, 0.382, 0.542, 0.453, 0.128,
0.128, 0.860, 0.778, 0.049, 0.974, 0.400,
@@ -144,22 +139,19 @@ def gen_convd_non_square_input_test_result():
0.078, 0.366, 0.396, 0.181, 0.246, 0.112,
]).reshape(1, 1, 4, 6)
weights = torch.tensor([
0.744, 0.745,
0.164, 0.157,
]).reshape(1, 1, 2, 2)
# fmt: on
output = _conv2d(
in_channels, out_channels, kernel_size, stride, padding, input, weights
)
output = _conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
input,
weights)
print_cpp_vector(output)
def gen_convd_non_square_kernel_test_result():
in_channels = 1
@@ -168,6 +160,7 @@ def gen_convd_non_square_kernel_test_result():
stride = 1
padding = 0
# fmt: off
input = torch.tensor([
0.946, 0.879, 0.382, 0.542,
0.128, 0.860, 0.778, 0.049,
@@ -178,17 +171,15 @@ def gen_convd_non_square_kernel_test_result():
weights = torch.tensor([
0.744, 0.745, 0.164
]).reshape(1, 1, 1, 3)
# fmt: on
output = _conv2d(
in_channels, out_channels, kernel_size, stride, padding, input, weights
)
output = _conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
input,
weights)
print_cpp_vector(output)
def gen_convd_non_square_stride_test_result():
in_channels = 1
@@ -197,6 +188,7 @@ def gen_convd_non_square_stride_test_result():
stride = (1, 2)
padding = 0
# fmt: off
input = torch.tensor([
0.946, 0.879, 0.382, 0.542,
0.128, 0.860, 0.778, 0.049,
@@ -208,17 +200,15 @@ def gen_convd_non_square_stride_test_result():
0.144, 0.745,
0.964, 0.164
]).reshape(1, 1, 2, 2)
# fmt: on
output = _conv2d(
in_channels, out_channels, kernel_size, stride, padding, input, weights
)
output = _conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
input,
weights)
print_cpp_vector(output)
def gen_convd_non_square_padding_test_result():
in_channels = 1
@@ -227,6 +217,7 @@ def gen_convd_non_square_padding_test_result():
stride = 1
padding = (1, 2)
# fmt: off
input = torch.tensor([
0.946, 0.879, 0.382, 0.542,
0.128, 0.860, 0.778, 0.049,
@@ -238,15 +229,12 @@ def gen_convd_non_square_padding_test_result():
0.144, 0.745,
0.964, 0.164
]).reshape(1, 1, 2, 2)
# fmt: on
output = _conv2d(
in_channels, out_channels, kernel_size, stride, padding, input, weights
)
output = _conv2d(in_channels,
out_channels,
kernel_size,
stride,
padding,
input,
weights)
print_cpp_vector(output)
@@ -263,4 +251,4 @@ if __name__ == "__main__":
print("Non square stride convolution test:")
gen_convd_non_square_stride_test_result()
print("Non square padding convolution test:")
gen_convd_non_square_padding_test_result()
gen_convd_non_square_padding_test_result()