diff --git a/examples/inception_v3/tests/inception_a.py b/examples/inception_v3/tests/inception_a.py new file mode 100644 index 0000000..e0f4fd6 --- /dev/null +++ b/examples/inception_v3/tests/inception_a.py @@ -0,0 +1,62 @@ +import sys + +import torch +from torchvision.models.inception import InceptionA + +sys.path.append("../../../tools") +from utils import print_cpp_vector + +torch.manual_seed(0) + +@torch.no_grad() +def init_weights(m): + if isinstance(m, torch.nn.Conv2d): + torch.nn.init.uniform_(m.weight) + elif isinstance(m, torch.nn.BatchNorm2d): + torch.nn.init.uniform_(m.weight) + torch.nn.init.uniform_(m.bias) + +with torch.no_grad(): + inception_a = InceptionA(3, 6) + inception_a.apply(init_weights) + + # branch1x1 + print_cpp_vector(torch.flatten(inception_a.branch1x1.conv.weight), "branch1x1_conv_weights") + print_cpp_vector(torch.flatten(inception_a.branch1x1.bn.weight), "branch1x1_bn_weights") + print_cpp_vector(torch.flatten(inception_a.branch1x1.bn.bias), "branch1x1_bn_bias") + + # branch5x5 + print_cpp_vector(torch.flatten(inception_a.branch5x5_1.conv.weight), "branch5x5_1_conv_weights") + print_cpp_vector(torch.flatten(inception_a.branch5x5_1.bn.weight), "branch5x5_1_bn_weights") + print_cpp_vector(torch.flatten(inception_a.branch5x5_1.bn.bias), "branch5x5_1_bn_bias") + + print_cpp_vector(torch.flatten(inception_a.branch5x5_2.conv.weight), "branch5x5_2_conv_weights") + print_cpp_vector(torch.flatten(inception_a.branch5x5_2.bn.weight), "branch5x5_2_bn_weights") + print_cpp_vector(torch.flatten(inception_a.branch5x5_2.bn.bias), "branch5x5_2_bn_bias") + + # branch3x3dbl + print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_1.conv.weight), "branch3x3dbl_1_conv_weights") + print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_1.bn.weight), "branch3x3dbl_1_bn_weights") + print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_1.bn.bias), "branch3x3dbl_1_bn_bias") + + print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_2.conv.weight), "branch3x3dbl_2_conv_weights") + print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_2.bn.weight), "branch3x3dbl_2_bn_weights") + print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_2.bn.bias), "branch3x3dbl_2_bn_bias") + + print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_3.conv.weight), "branch3x3dbl_3_conv_weights") + print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_3.bn.weight), "branch3x3dbl_3_bn_weights") + print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_3.bn.bias), "branch3x3dbl_3_bn_bias") + + # branchPool + print_cpp_vector(torch.flatten(inception_a.branch_pool.conv.weight), "branchPool_2_conv_weights") + print_cpp_vector(torch.flatten(inception_a.branch_pool.bn.weight), "branchPool_2_bn_weights") + print_cpp_vector(torch.flatten(inception_a.branch_pool.bn.bias), "branchPool_2_bn_bias") + + input_shape = (1, 3, 8, 8) + input = torch.randn(input_shape) + print_cpp_vector(torch.flatten(input), "input") + + output = inception_a(input) + output = torch.flatten(output) + print_cpp_vector(output) +