Add running mean and running var to batchnorm

This commit is contained in:
2024-08-25 19:05:10 +02:00
parent 1136ca452f
commit 9704d0d53e
8 changed files with 205 additions and 71 deletions

View File

@@ -35,6 +35,19 @@ def export_model_weights(model: torch.nn.Module, filename):
tensor_data += tensor_bytes
# print(model.named_buffers)
# Add buffers (for running_mean and running_var)
for name, buf in model.named_buffers():
if "running_mean" not in name and "running_var" not in name:
continue
tensor_bytes = buf.type(torch.float32).detach().numpy().tobytes()
tensor_size = buf.numel()
header += f"{name},{tensor_size},{offset}\n"
offset += len(tensor_bytes)
tensor_data += tensor_bytes
f.seek(0)
f.write(struct.pack("H", version))
f.write(struct.pack("Q", len(header)))