mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-12-23 06:44:24 +00:00
Add running mean and running var to batchnorm
This commit is contained in:
@@ -5,7 +5,7 @@ from utils import print_cpp_vector
|
||||
|
||||
def gen_batch_norm_test_result(input):
|
||||
|
||||
batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=False)
|
||||
batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=True)
|
||||
|
||||
weights = torch.Tensor([0.63508, 0.64903])
|
||||
biases = torch.Tensor([0.25079, 0.66841])
|
||||
@@ -13,7 +13,13 @@ def gen_batch_norm_test_result(input):
|
||||
batch_norm.weight = torch.nn.Parameter(weights)
|
||||
batch_norm.bias = torch.nn.Parameter(biases)
|
||||
|
||||
batch_norm.running_mean = torch.Tensor([0.5, 0.5])
|
||||
batch_norm.running_var = torch.Tensor([1.0, 1.0])
|
||||
|
||||
batch_norm.eval()
|
||||
|
||||
output = batch_norm(input)
|
||||
|
||||
print_cpp_vector(output.flatten())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user