mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 09:44:28 +00:00
Add non square pooling and batch norm tests
This commit is contained in:
@@ -2,30 +2,51 @@ import torch
|
||||
|
||||
from utils import print_cpp_vector
|
||||
|
||||
batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=False)
|
||||
def gen_batch_norm_test_result(input):
|
||||
|
||||
weights = torch.Tensor([0.63508, 0.64903])
|
||||
biases= torch.Tensor([0.25079, 0.66841])
|
||||
batch_norm = torch.nn.BatchNorm2d(2, track_running_stats=False)
|
||||
|
||||
batch_norm.weight = torch.nn.Parameter(weights)
|
||||
batch_norm.bias = torch.nn.Parameter(biases)
|
||||
weights = torch.Tensor([0.63508, 0.64903])
|
||||
biases= torch.Tensor([0.25079, 0.66841])
|
||||
|
||||
input = torch.Tensor([
|
||||
# Channel 0
|
||||
0.38899, 0.80478, 0.48836, 0.97381,
|
||||
0.57508, 0.60835, 0.65467, 0.00168,
|
||||
0.65869, 0.74235, 0.17928, 0.70349,
|
||||
0.15524, 0.38664, 0.23411, 0.7137,
|
||||
# Channel 1
|
||||
0.32473, 0.15698, 0.314, 0.60888,
|
||||
0.80268, 0.99766, 0.93694, 0.89237,
|
||||
0.13449, 0.27367, 0.53036, 0.18962,
|
||||
0.57672, 0.48364, 0.10863, 0.0571
|
||||
]).reshape(1, 2, 4, 4)
|
||||
batch_norm.weight = torch.nn.Parameter(weights)
|
||||
batch_norm.bias = torch.nn.Parameter(biases)
|
||||
|
||||
output = batch_norm(input)
|
||||
print_cpp_vector(output.flatten())
|
||||
output = batch_norm(input)
|
||||
print_cpp_vector(output.flatten())
|
||||
|
||||
print(batch_norm.running_mean)
|
||||
print(batch_norm.running_var)
|
||||
if __name__ == "__main__":
|
||||
|
||||
print("Generating test results...")
|
||||
print("Batch norm test:")
|
||||
|
||||
input = torch.Tensor([
|
||||
# Channel 0
|
||||
0.38899, 0.80478, 0.48836, 0.97381,
|
||||
0.57508, 0.60835, 0.65467, 0.00168,
|
||||
0.65869, 0.74235, 0.17928, 0.70349,
|
||||
0.15524, 0.38664, 0.23411, 0.7137,
|
||||
# Channel 1
|
||||
0.32473, 0.15698, 0.314, 0.60888,
|
||||
0.80268, 0.99766, 0.93694, 0.89237,
|
||||
0.13449, 0.27367, 0.53036, 0.18962,
|
||||
0.57672, 0.48364, 0.10863, 0.0571
|
||||
]).reshape(1, 2, 4, 4)
|
||||
|
||||
gen_batch_norm_test_result(input)
|
||||
|
||||
print("Batch norm test non square input:")
|
||||
|
||||
input = torch.Tensor([
|
||||
0.38899, 0.80478, 0.48836, 0.97381, 0.21567, 0.92312,
|
||||
0.57508, 0.60835, 0.65467, 0.00168, 0.31567, 0.71345,
|
||||
0.65869, 0.74235, 0.17928, 0.70349, 0.12856, 0.95645,
|
||||
0.15524, 0.38664, 0.23411, 0.7137, 0.26789, 0.83412,
|
||||
0.32473, 0.15698, 0.314, 0.60888, 0.23145, 0.78945,
|
||||
0.80268, 0.99766, 0.93694, 0.89237, 0.61234, 0.92314,
|
||||
0.13449, 0.27367, 0.53036, 0.18962, 0.45623, 0.14523,
|
||||
0.57672, 0.48364, 0.10863, 0.0571, 0.78934, 0.67545
|
||||
]).reshape(1, 2, 4, 6)
|
||||
|
||||
gen_batch_norm_test_result(input)
|
||||
|
||||
|
||||
@@ -14,6 +14,19 @@ def _get_pool_input():
|
||||
0.532, 0.819, 0.732, 0.850
|
||||
]).reshape(1, 2, 4, 4)
|
||||
|
||||
def _get_pool_input_non_square():
|
||||
return torch.Tensor([
|
||||
0.573, 0.619, 0.732, 0.055, 0.123, 0.234,
|
||||
0.243, 0.316, 0.573, 0.619, 0.456, 0.789,
|
||||
0.712, 0.055, 0.243, 0.316, 0.654, 0.987,
|
||||
0.573, 0.619, 0.742, 0.055, 0.321, 0.654,
|
||||
0.473, 0.919, 0.107, 0.073, 0.321, 0.654,
|
||||
0.073, 0.362, 0.973, 0.059, 0.654, 0.987,
|
||||
0.473, 0.455, 0.283, 0.416, 0.789, 0.123,
|
||||
0.532, 0.819, 0.732, 0.850, 0.987, 0.321
|
||||
]).reshape(1, 2, 4, 6)
|
||||
|
||||
|
||||
def gen_max_pool_test_result():
|
||||
input = _get_pool_input()
|
||||
|
||||
@@ -23,6 +36,33 @@ def gen_max_pool_test_result():
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
def gen_max_pool_non_square_input_test_result():
|
||||
input = _get_pool_input_non_square()
|
||||
|
||||
output = torch.nn.MaxPool2d(kernel_size=2, stride=2)(input)
|
||||
output = torch.flatten(output)
|
||||
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
def gen_max_non_square_pool_test_result():
|
||||
input = _get_pool_input()
|
||||
|
||||
output = torch.nn.MaxPool2d(kernel_size=(2, 3), stride=2)(input)
|
||||
output = torch.flatten(output)
|
||||
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
def gen_max_pool_non_square_stride_test_result():
|
||||
input = _get_pool_input()
|
||||
|
||||
output = torch.nn.MaxPool2d(kernel_size=2, stride=(1, 2))(input)
|
||||
output = torch.flatten(output)
|
||||
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
def gen_avg_pool_test_result():
|
||||
|
||||
input = _get_pool_input()
|
||||
@@ -33,9 +73,55 @@ def gen_avg_pool_test_result():
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
def gen_avg_pool_non_square_input_test_result():
|
||||
|
||||
input = _get_pool_input_non_square()
|
||||
|
||||
output = torch.nn.AvgPool2d(kernel_size=2, stride=2)(input)
|
||||
output = torch.flatten(output)
|
||||
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
def gen_avg_non_square_pool_test_result():
|
||||
|
||||
input = _get_pool_input()
|
||||
|
||||
output = torch.nn.AvgPool2d(kernel_size=(2, 3), stride=2)(input)
|
||||
output = torch.flatten(output)
|
||||
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
def gen_avg_pool_non_square_stride_test_result():
|
||||
|
||||
input = _get_pool_input()
|
||||
|
||||
output = torch.nn.AvgPool2d(kernel_size=2, stride=(1, 2))(input)
|
||||
output = torch.flatten(output)
|
||||
|
||||
print_cpp_vector(output)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Generating test results...")
|
||||
print("Max pool test:")
|
||||
gen_max_pool_test_result()
|
||||
print("Max pool non square input test:")
|
||||
gen_max_pool_non_square_input_test_result()
|
||||
print("Max non square pool test:")
|
||||
gen_max_non_square_pool_test_result()
|
||||
print("Max pool non square stride test:")
|
||||
gen_max_pool_non_square_stride_test_result()
|
||||
|
||||
print("--------------")
|
||||
|
||||
print("Avg pool test:")
|
||||
gen_avg_pool_test_result()
|
||||
print("Avg pool non square input test:")
|
||||
gen_avg_pool_non_square_input_test_result()
|
||||
print("Avg non square pool test:")
|
||||
gen_avg_non_square_pool_test_result()
|
||||
print("Avg pool non square stride test:")
|
||||
gen_avg_pool_non_square_stride_test_result()
|
||||
|
||||
Reference in New Issue
Block a user