mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-08 02:34:24 +00:00
Add running mean and running var to batchnorm
This commit is contained in:
@@ -5,7 +5,7 @@ from utils import print_cpp_vector
|
||||
|
||||
def gen_batch_norm_test_result(input):
|
||||
|
||||
batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=False)
|
||||
batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=True)
|
||||
|
||||
weights = torch.Tensor([0.63508, 0.64903])
|
||||
biases = torch.Tensor([0.25079, 0.66841])
|
||||
@@ -13,7 +13,13 @@ def gen_batch_norm_test_result(input):
|
||||
batch_norm.weight = torch.nn.Parameter(weights)
|
||||
batch_norm.bias = torch.nn.Parameter(biases)
|
||||
|
||||
batch_norm.running_mean = torch.Tensor([0.5, 0.5])
|
||||
batch_norm.running_var = torch.Tensor([1.0, 1.0])
|
||||
|
||||
batch_norm.eval()
|
||||
|
||||
output = batch_norm(input)
|
||||
|
||||
print_cpp_vector(output.flatten())
|
||||
|
||||
|
||||
|
||||
@@ -35,6 +35,19 @@ def export_model_weights(model: torch.nn.Module, filename):
|
||||
|
||||
tensor_data += tensor_bytes
|
||||
|
||||
# print(model.named_buffers)
|
||||
|
||||
# Add buffers (for running_mean and running_var)
|
||||
for name, buf in model.named_buffers():
|
||||
if "running_mean" not in name and "running_var" not in name:
|
||||
continue
|
||||
|
||||
tensor_bytes = buf.type(torch.float32).detach().numpy().tobytes()
|
||||
tensor_size = buf.numel()
|
||||
header += f"{name},{tensor_size},{offset}\n"
|
||||
offset += len(tensor_bytes)
|
||||
tensor_data += tensor_bytes
|
||||
|
||||
f.seek(0)
|
||||
f.write(struct.pack("H", version))
|
||||
f.write(struct.pack("Q", len(header)))
|
||||
|
||||
Reference in New Issue
Block a user