Cleanup and refactor

This commit is contained in:
2024-04-11 22:52:41 +02:00
parent 4b9d123e94
commit 18522c2dea
9 changed files with 81 additions and 72 deletions

44
tools/dense_test.py Normal file
View File

@@ -0,0 +1,44 @@
import torch
from utils import print_cpp_vector
def gen_dense_softmax_test():
input = torch.tensor([
0.1, 0.2, 0.3, 0.4, 0.5
])
weights = torch.tensor([
0.5, 0.1, 0.1, 0.4, 0.2,
0.4, 0.3, 0.9, 0.0, 0.8,
0.8, 0.4, 0.6, 0.2, 0.0,
0.1, 0.7, 0.3, 1.0, 0.1
]).reshape(4, 5)
biases = torch.tensor([
0.1, 0.2, 0.3, 0.4
])
dense = torch.nn.Linear(5, 4)
dense.weight = torch.nn.Parameter(weights)
dense.bias = torch.nn.Parameter(biases)
output = dense(input)
print_cpp_vector(output)
# Manual softmax
softmax_exp = torch.exp(output)
print(softmax_exp)
softmax_sum = torch.sum(softmax_exp, dim=0)
print(softmax_sum)
souftmax_out = softmax_exp / softmax_sum
print(souftmax_out)
softmax = torch.nn.Softmax(dim=0)(output)
print_cpp_vector(softmax)
if __name__ == "__main__":
gen_dense_softmax_test()