mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Run black autoformatter
This commit is contained in:
@@ -2,20 +2,17 @@ import torch
|
||||
|
||||
from utils import print_cpp_vector
|
||||
|
||||
def _conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
inputs,
|
||||
weights):
|
||||
|
||||
conv2d = torch.nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=False)
|
||||
def _conv2d(in_channels, out_channels, kernel_size, stride, padding, inputs, weights):
|
||||
|
||||
conv2d = torch.nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=False,
|
||||
)
|
||||
conv2d.weight = torch.nn.Parameter(weights)
|
||||
|
||||
output = conv2d(inputs)
|
||||
@@ -24,6 +21,7 @@ def _conv2d(in_channels,
|
||||
output = torch.flatten(output)
|
||||
return output
|
||||
|
||||
|
||||
def gen_convd_padded_test_result():
|
||||
|
||||
in_channels = 3
|
||||
@@ -32,7 +30,7 @@ def gen_convd_padded_test_result():
|
||||
stride = 1
|
||||
padding = 1
|
||||
|
||||
# Define input and kernel data as tensors
|
||||
# fmt: off
|
||||
inputs = torch.tensor([
|
||||
0.823, 0.217, 0.435, 0.981, 0.742,
|
||||
0.109, 0.518, 0.374, 0.681, 0.147,
|
||||
@@ -71,15 +69,12 @@ def gen_convd_padded_test_result():
|
||||
0.678, 0.011, 0.345,
|
||||
0.011, 0.345, 0.678
|
||||
], dtype=torch.float).reshape(2, 3, 3, 3)
|
||||
# fmt: on
|
||||
|
||||
output = _conv2d(
|
||||
in_channels, out_channels, kernel_size, stride, padding, inputs, weights
|
||||
)
|
||||
|
||||
output = _conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
inputs,
|
||||
weights)
|
||||
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
@@ -91,6 +86,7 @@ def gen_convd_strided_test_result():
|
||||
stride = 2
|
||||
padding = 3
|
||||
|
||||
# fmt: off
|
||||
input = torch.tensor([
|
||||
0.946, 0.879, 0.382, 0.542, 0.453,
|
||||
0.128, 0.860, 0.778, 0.049, 0.974,
|
||||
@@ -103,6 +99,7 @@ def gen_convd_strided_test_result():
|
||||
0.473, 0.303, 0.084, 0.785, 0.444,
|
||||
0.464, 0.413, 0.779, 0.298, 0.783
|
||||
], dtype=torch.float).reshape(1, 2, 5, 5)
|
||||
|
||||
weights = torch.tensor([
|
||||
0.744, 0.745, 0.641,
|
||||
0.164, 0.157, 0.127,
|
||||
@@ -117,15 +114,12 @@ def gen_convd_strided_test_result():
|
||||
0.236, 0.397, 0.739,
|
||||
0.939, 0.891, 0.006
|
||||
], dtype=torch.float).reshape(2, 2, 3, 3)
|
||||
# fmt: on
|
||||
|
||||
output = _conv2d(
|
||||
in_channels, out_channels, kernel_size, stride, padding, input, weights
|
||||
)
|
||||
|
||||
output = _conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
input,
|
||||
weights)
|
||||
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
@@ -137,6 +131,7 @@ def gen_convd_non_square_input_test_result():
|
||||
stride = 1
|
||||
padding = 0
|
||||
|
||||
# fmt: off
|
||||
input = torch.tensor([
|
||||
0.946, 0.879, 0.382, 0.542, 0.453, 0.128,
|
||||
0.128, 0.860, 0.778, 0.049, 0.974, 0.400,
|
||||
@@ -144,22 +139,19 @@ def gen_convd_non_square_input_test_result():
|
||||
0.078, 0.366, 0.396, 0.181, 0.246, 0.112,
|
||||
]).reshape(1, 1, 4, 6)
|
||||
|
||||
|
||||
weights = torch.tensor([
|
||||
0.744, 0.745,
|
||||
0.164, 0.157,
|
||||
]).reshape(1, 1, 2, 2)
|
||||
# fmt: on
|
||||
|
||||
output = _conv2d(
|
||||
in_channels, out_channels, kernel_size, stride, padding, input, weights
|
||||
)
|
||||
|
||||
output = _conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
input,
|
||||
weights)
|
||||
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
def gen_convd_non_square_kernel_test_result():
|
||||
|
||||
in_channels = 1
|
||||
@@ -168,6 +160,7 @@ def gen_convd_non_square_kernel_test_result():
|
||||
stride = 1
|
||||
padding = 0
|
||||
|
||||
# fmt: off
|
||||
input = torch.tensor([
|
||||
0.946, 0.879, 0.382, 0.542,
|
||||
0.128, 0.860, 0.778, 0.049,
|
||||
@@ -178,17 +171,15 @@ def gen_convd_non_square_kernel_test_result():
|
||||
weights = torch.tensor([
|
||||
0.744, 0.745, 0.164
|
||||
]).reshape(1, 1, 1, 3)
|
||||
# fmt: on
|
||||
|
||||
output = _conv2d(
|
||||
in_channels, out_channels, kernel_size, stride, padding, input, weights
|
||||
)
|
||||
|
||||
output = _conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
input,
|
||||
weights)
|
||||
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
def gen_convd_non_square_stride_test_result():
|
||||
|
||||
in_channels = 1
|
||||
@@ -197,6 +188,7 @@ def gen_convd_non_square_stride_test_result():
|
||||
stride = (1, 2)
|
||||
padding = 0
|
||||
|
||||
# fmt: off
|
||||
input = torch.tensor([
|
||||
0.946, 0.879, 0.382, 0.542,
|
||||
0.128, 0.860, 0.778, 0.049,
|
||||
@@ -208,17 +200,15 @@ def gen_convd_non_square_stride_test_result():
|
||||
0.144, 0.745,
|
||||
0.964, 0.164
|
||||
]).reshape(1, 1, 2, 2)
|
||||
# fmt: on
|
||||
|
||||
output = _conv2d(
|
||||
in_channels, out_channels, kernel_size, stride, padding, input, weights
|
||||
)
|
||||
|
||||
output = _conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
input,
|
||||
weights)
|
||||
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
def gen_convd_non_square_padding_test_result():
|
||||
|
||||
in_channels = 1
|
||||
@@ -227,6 +217,7 @@ def gen_convd_non_square_padding_test_result():
|
||||
stride = 1
|
||||
padding = (1, 2)
|
||||
|
||||
# fmt: off
|
||||
input = torch.tensor([
|
||||
0.946, 0.879, 0.382, 0.542,
|
||||
0.128, 0.860, 0.778, 0.049,
|
||||
@@ -238,15 +229,12 @@ def gen_convd_non_square_padding_test_result():
|
||||
0.144, 0.745,
|
||||
0.964, 0.164
|
||||
]).reshape(1, 1, 2, 2)
|
||||
# fmt: on
|
||||
|
||||
output = _conv2d(
|
||||
in_channels, out_channels, kernel_size, stride, padding, input, weights
|
||||
)
|
||||
|
||||
output = _conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
input,
|
||||
weights)
|
||||
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
@@ -263,4 +251,4 @@ if __name__ == "__main__":
|
||||
print("Non square stride convolution test:")
|
||||
gen_convd_non_square_stride_test_result()
|
||||
print("Non square padding convolution test:")
|
||||
gen_convd_non_square_padding_test_result()
|
||||
gen_convd_non_square_padding_test_result()
|
||||
|
||||
Reference in New Issue
Block a user