# Mirror of https://github.com/lordmathis/CUDANet.git
# (synced 2025-11-05 17:34:21 +00:00) - 73 lines, 3.7 KiB, Python.
import os
import sys

import torch
from torchvision.models.inception import InceptionC

# Make the shared test helpers importable. They live in the `tools` directory
# three levels above this script. The path is resolved relative to this file
# rather than to the current working directory, so the generator also works
# when it is launched from a different directory.
_TOOLS_DIR = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../tools")
)
sys.path.append(_TOOLS_DIR)

from utils import print_cpp_vector  # noqa: E402  (requires the sys.path tweak above)

# Seed the global RNG so that the uniform weight re-initialization and the
# random input generated below are reproducible across runs.
torch.manual_seed(0)
|
|
|
|
@torch.no_grad()
def init_weights(m):
    """Redraw selected parameters of ``m`` in place with ``torch.nn.init.uniform_``.

    Meant to be passed to ``nn.Module.apply``:

    * ``Conv2d``: only ``weight`` is redrawn.
    * ``BatchNorm2d``: ``weight`` is redrawn, then ``bias``.
    * Any other module type is left untouched.

    ``uniform_`` is called with its default bounds. The ``torch.no_grad()``
    decorator keeps autograd from recording the in-place writes.
    """
    if isinstance(m, torch.nn.Conv2d):
        targets = (m.weight,)
    elif isinstance(m, torch.nn.BatchNorm2d):
        # Weight before bias: this order determines which random draws each gets.
        targets = (m.weight, m.bias)
    else:
        targets = ()

    for tensor in targets:
        torch.nn.init.uniform_(tensor)
|
|
|
|
with torch.no_grad():
    inception_c = InceptionC(3, 64)
    inception_c.apply(init_weights)

    # Each entry pairs an InceptionC submodule (a unit with `.conv` and `.bn`
    # children) with the prefix used for its emitted C++ vector names.
    # The tuple order is the order in which the vectors are printed.
    # `branch_pool` is deliberately emitted under the prefix "branchPool_2".
    conv_blocks = (
        ("branch1x1", "branch1x1"),
        ("branch7x7_1", "branch7x7_1"),
        ("branch7x7_2", "branch7x7_2"),
        ("branch7x7_3", "branch7x7_3"),
        ("branch7x7dbl_1", "branch7x7dbl_1"),
        ("branch7x7dbl_2", "branch7x7dbl_2"),
        ("branch7x7dbl_3", "branch7x7dbl_3"),
        ("branch7x7dbl_4", "branch7x7dbl_4"),
        ("branch7x7dbl_5", "branch7x7dbl_5"),
        ("branch_pool", "branchPool_2"),
    )
    for attr_name, prefix in conv_blocks:
        conv_block = getattr(inception_c, attr_name)
        print_cpp_vector(torch.flatten(conv_block.conv.weight), f"{prefix}_conv_weights")
        print_cpp_vector(torch.flatten(conv_block.bn.weight), f"{prefix}_bn_weights")
        print_cpp_vector(torch.flatten(conv_block.bn.bias), f"{prefix}_bn_bias")

    # Random input of shape (1, 3, 8, 8); dim 1 matches the 3 passed to InceptionC.
    # Named `input_tensor` so that the builtin `input` is not shadowed.
    input_shape = (1, 3, 8, 8)
    input_tensor = torch.randn(input_shape)
    print_cpp_vector(torch.flatten(input_tensor), "input")

    # NOTE(review): `inception_c` is never switched to eval mode. Its BatchNorm
    # layers therefore normalize with the statistics of `input_tensor`, not with
    # running mean/var (which are never exported above). Presumably this matches
    # the CUDANet layer under test; confirm before adding `.eval()`.
    output = torch.flatten(inception_c(input_tensor))
    print_cpp_vector(output, "expected")