Add running mean and running var to batchnorm

This commit is contained in:
2024-08-25 19:05:10 +02:00
parent 1136ca452f
commit 9704d0d53e
8 changed files with 205 additions and 71 deletions

View File

@@ -5,7 +5,7 @@ from utils import print_cpp_vector
def gen_batch_norm_test_result(input):
batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=False)
batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=True)
weights = torch.Tensor([0.63508, 0.64903])
biases = torch.Tensor([0.25079, 0.66841])
@@ -13,7 +13,13 @@ def gen_batch_norm_test_result(input):
batch_norm.weight = torch.nn.Parameter(weights)
batch_norm.bias = torch.nn.Parameter(biases)
batch_norm.running_mean = torch.Tensor([0.5, 0.5])
batch_norm.running_var = torch.Tensor([1.0, 1.0])
batch_norm.eval()
output = batch_norm(input)
print_cpp_vector(output.flatten())