Add running mean and running var to batchnorm

commit 9704d0d53e
parent 1136ca452f
Date:   2024-08-25 19:05:10 +02:00
8 changed files with 205 additions and 71 deletions


@@ -5,7 +5,7 @@ from utils import print_cpp_vector
 def gen_batch_norm_test_result(input):
-    batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=False)
+    batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=True)
     weights = torch.Tensor([0.63508, 0.64903])
     biases = torch.Tensor([0.25079, 0.66841])
@@ -13,7 +13,13 @@ def gen_batch_norm_test_result(input):
     batch_norm.weight = torch.nn.Parameter(weights)
     batch_norm.bias = torch.nn.Parameter(biases)
+    batch_norm.running_mean = torch.Tensor([0.5, 0.5])
+    batch_norm.running_var = torch.Tensor([1.0, 1.0])
+    batch_norm.eval()
     output = batch_norm(input)
     print_cpp_vector(output.flatten())
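
With track_running_stats=True and the module switched to eval(), BatchNorm2d normalizes each channel with the stored running statistics instead of the batch statistics: y = (x - running_mean) / sqrt(running_var + eps) * weight + bias. A minimal sketch of that equivalence using the values from the hunk above (the random input and its shape are placeholders, not taken from the test generator):

import torch

def manual_batch_norm(x, mean, var, weight, bias, eps=1e-5):
    # x is (N, C, H, W); broadcast the per-channel stats over N, H, W
    s = (1, -1, 1, 1)
    return (x - mean.view(s)) / torch.sqrt(var.view(s) + eps) * weight.view(s) + bias.view(s)

x = torch.randn(1, 2, 4, 4)  # placeholder input with 2 channels

bn = torch.nn.BatchNorm2d(2, track_running_stats=True)
bn.weight = torch.nn.Parameter(torch.Tensor([0.63508, 0.64903]))
bn.bias = torch.nn.Parameter(torch.Tensor([0.25079, 0.66841]))
bn.running_mean = torch.Tensor([0.5, 0.5])
bn.running_var = torch.Tensor([1.0, 1.0])
bn.eval()

# eval-mode output matches the closed-form expression with the running stats
expected = manual_batch_norm(x, bn.running_mean, bn.running_var, bn.weight, bn.bias)
assert torch.allclose(bn(x), expected, atol=1e-6)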


@@ -35,6 +35,19 @@ def export_model_weights(model: torch.nn.Module, filename):
         tensor_data += tensor_bytes
+    # print(model.named_buffers)
+    # Add buffers (for running_mean and running_var)
+    for name, buf in model.named_buffers():
+        if "running_mean" not in name and "running_var" not in name:
+            continue
+        tensor_bytes = buf.type(torch.float32).detach().numpy().tobytes()
+        tensor_size = buf.numel()
+        header += f"{name},{tensor_size},{offset}\n"
+        offset += len(tensor_bytes)
+        tensor_data += tensor_bytes
     f.seek(0)
     f.write(struct.pack("H", version))
     f.write(struct.pack("Q", len(header)))