mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Add running mean and running var to batchnorm
This commit is contained in:
@@ -35,6 +35,19 @@ def export_model_weights(model: torch.nn.Module, filename):
|
||||
|
||||
tensor_data += tensor_bytes
|
||||
|
||||
# print(model.named_buffers)
|
||||
|
||||
# Add buffers (for running_mean and running_var)
|
||||
for name, buf in model.named_buffers():
|
||||
if "running_mean" not in name and "running_var" not in name:
|
||||
continue
|
||||
|
||||
tensor_bytes = buf.type(torch.float32).detach().numpy().tobytes()
|
||||
tensor_size = buf.numel()
|
||||
header += f"{name},{tensor_size},{offset}\n"
|
||||
offset += len(tensor_bytes)
|
||||
tensor_data += tensor_bytes
|
||||
|
||||
f.seek(0)
|
||||
f.write(struct.pack("H", version))
|
||||
f.write(struct.pack("Q", len(header)))
|
||||
|
||||
Reference in New Issue
Block a user