Test softmax

This commit is contained in:
2024-03-17 19:08:16 +01:00
parent 42d646750b
commit cbdb4e7707
5 changed files with 49 additions and 4 deletions

View File

@@ -110,10 +110,20 @@ def gen_strided_test_result():
output = conv2d(in_channels, out_channels, kernel_size, stride, padding, input, weights)
print_cpp_vector(output)
def gen_softmax_test_result():
input = torch.tensor([
0.573, 0.619, 0.732, 0.055, 0.243
])
output = torch.nn.Softmax(dim=0)(input)
print_cpp_vector(output)
if __name__ == "__main__":
print("Generating test results...")
print("Padded convolution test:")
gen_padded_test_result()
print("Strided convolution test:")
gen_strided_test_result()
gen_strided_test_result()
print("Softmax test:")
gen_softmax_test_result()