mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Test inception module B
This commit is contained in:
43
examples/inception_v3/tests/inception_b.py
Normal file
43
examples/inception_v3/tests/inception_b.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torchvision.models.inception import InceptionB
|
||||||
|
|
||||||
|
sys.path.append("../../../tools")
|
||||||
|
from utils import print_cpp_vector
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def init_weights(m):
|
||||||
|
if isinstance(m, torch.nn.Conv2d):
|
||||||
|
torch.nn.init.uniform_(m.weight)
|
||||||
|
elif isinstance(m, torch.nn.BatchNorm2d):
|
||||||
|
torch.nn.init.uniform_(m.weight)
|
||||||
|
torch.nn.init.uniform_(m.bias)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
inception_b = InceptionB(3)
|
||||||
|
inception_b.apply(init_weights)
|
||||||
|
|
||||||
|
# branch3x3
|
||||||
|
print_cpp_vector(torch.flatten(inception_b.branch3x3.conv.weight), "branch3x3_conv_weights")
|
||||||
|
print_cpp_vector(torch.flatten(inception_b.branch3x3.bn.weight), "branch3x3_bn_weights")
|
||||||
|
print_cpp_vector(torch.flatten(inception_b.branch3x3.bn.bias), "branch3x3_bn_bias")
|
||||||
|
|
||||||
|
# branch3x3dbl
|
||||||
|
print_cpp_vector(torch.flatten(inception_b.branch3x3dbl_1.conv.weight), "branch3x3dbl_1_conv_weights")
|
||||||
|
print_cpp_vector(torch.flatten(inception_b.branch3x3dbl_1.bn.weight), "branch3x3dbl_1_bn_weights")
|
||||||
|
print_cpp_vector(torch.flatten(inception_b.branch3x3dbl_1.bn.bias), "branch3x3dbl_1_bn_bias")
|
||||||
|
|
||||||
|
print_cpp_vector(torch.flatten(inception_b.branch3x3dbl_2.conv.weight), "branch3x3dbl_2_conv_weights")
|
||||||
|
print_cpp_vector(torch.flatten(inception_b.branch3x3dbl_2.bn.weight), "branch3x3dbl_2_bn_weights")
|
||||||
|
print_cpp_vector(torch.flatten(inception_b.branch3x3dbl_2.bn.bias), "branch3x3dbl_2_bn_bias")
|
||||||
|
|
||||||
|
input_shape = (1, 3, 8, 8)
|
||||||
|
input = torch.randn(input_shape)
|
||||||
|
print_cpp_vector(torch.flatten(input), "input")
|
||||||
|
|
||||||
|
output = inception_b(input)
|
||||||
|
output = torch.flatten(output)
|
||||||
|
print_cpp_vector(output, "expected")
|
||||||
@@ -94,8 +94,6 @@ class InceptionATest : public ::testing::Test {
|
|||||||
inception_a =
|
inception_a =
|
||||||
new InceptionA(inputShape, inputChannels, poolFeatures, prefix);
|
new InceptionA(inputShape, inputChannels, poolFeatures, prefix);
|
||||||
|
|
||||||
CUDANet::Layers::Conv2d *conv;
|
|
||||||
|
|
||||||
// Set up layer weights and bias
|
// Set up layer weights and bias
|
||||||
// Branch 1x1
|
// Branch 1x1
|
||||||
setBasicConv2dWeights(branch1x1_conv_weights, branch1x1_bn_weights,
|
setBasicConv2dWeights(branch1x1_conv_weights, branch1x1_bn_weights,
|
||||||
|
|||||||
10382
examples/inception_v3/tests/test_inception_b.cpp
Normal file
10382
examples/inception_v3/tests/test_inception_b.cpp
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user