mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Generate inception A test inputs
This commit is contained in:
62
examples/inception_v3/tests/inception_a.py
Normal file
62
examples/inception_v3/tests/inception_a.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torchvision.models.inception import InceptionA
|
||||
|
||||
sys.path.append("../../../tools")
|
||||
from utils import print_cpp_vector
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@torch.no_grad()
|
||||
def init_weights(m):
|
||||
if isinstance(m, torch.nn.Conv2d):
|
||||
torch.nn.init.uniform_(m.weight)
|
||||
elif isinstance(m, torch.nn.BatchNorm2d):
|
||||
torch.nn.init.uniform_(m.weight)
|
||||
torch.nn.init.uniform_(m.bias)
|
||||
|
||||
with torch.no_grad():
|
||||
inception_a = InceptionA(3, 6)
|
||||
inception_a.apply(init_weights)
|
||||
|
||||
# branch1x1
|
||||
print_cpp_vector(torch.flatten(inception_a.branch1x1.conv.weight), "branch1x1_conv_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch1x1.bn.weight), "branch1x1_bn_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch1x1.bn.bias), "branch1x1_bn_bias")
|
||||
|
||||
# branch5x5
|
||||
print_cpp_vector(torch.flatten(inception_a.branch5x5_1.conv.weight), "branch5x5_1_conv_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch5x5_1.bn.weight), "branch5x5_1_bn_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch5x5_1.bn.bias), "branch5x5_1_bn_bias")
|
||||
|
||||
print_cpp_vector(torch.flatten(inception_a.branch5x5_2.conv.weight), "branch5x5_2_conv_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch5x5_2.bn.weight), "branch5x5_2_bn_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch5x5_2.bn.bias), "branch5x5_2_bn_bias")
|
||||
|
||||
# branch3x3dbl
|
||||
print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_1.conv.weight), "branch3x3dbl_1_conv_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_1.bn.weight), "branch3x3dbl_1_bn_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_1.bn.bias), "branch3x3dbl_1_bn_bias")
|
||||
|
||||
print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_2.conv.weight), "branch3x3dbl_2_conv_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_2.bn.weight), "branch3x3dbl_2_bn_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_2.bn.bias), "branch3x3dbl_2_bn_bias")
|
||||
|
||||
print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_3.conv.weight), "branch3x3dbl_3_conv_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_3.bn.weight), "branch3x3dbl_3_bn_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch3x3dbl_3.bn.bias), "branch3x3dbl_3_bn_bias")
|
||||
|
||||
# branchPool
|
||||
print_cpp_vector(torch.flatten(inception_a.branch_pool.conv.weight), "branchPool_2_conv_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch_pool.bn.weight), "branchPool_2_bn_weights")
|
||||
print_cpp_vector(torch.flatten(inception_a.branch_pool.bn.bias), "branchPool_2_bn_bias")
|
||||
|
||||
input_shape = (1, 3, 8, 8)
|
||||
input = torch.randn(input_shape)
|
||||
print_cpp_vector(torch.flatten(input), "input")
|
||||
|
||||
output = inception_a(input)
|
||||
output = torch.flatten(output)
|
||||
print_cpp_vector(output)
|
||||
|
||||
Reference in New Issue
Block a user