mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 17:54:27 +00:00
Rework inception block tests
This commit is contained in:
83
examples/inception_v3/tests/inception_blocks.py
Normal file
83
examples/inception_v3/tests/inception_blocks.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torchvision.models.inception import (
|
||||
InceptionA,
|
||||
InceptionB,
|
||||
InceptionC,
|
||||
InceptionD,
|
||||
InceptionE
|
||||
)
|
||||
|
||||
sys.path.append("../../../tools")
|
||||
from utils import print_cpp_vector, export_model_weights
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
output_size = 50
|
||||
|
||||
class InceptionBlockModel(torch.nn.Module):
|
||||
def __init__(self, inception_block: torch.nn.Module, linear_in: int, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.inception_block = inception_block
|
||||
self.fc = torch.nn.Linear(linear_in, output_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.inception_block(x)
|
||||
x = torch.flatten(x)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def init_weights(m: torch.nn.Module):
|
||||
if isinstance(m, torch.nn.Conv2d):
|
||||
torch.nn.init.uniform_(m.weight, -1, 1)
|
||||
elif isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.uniform_(m.weight, -1)
|
||||
torch.nn.init.uniform_(m.bias, 1)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_module_test_data(m: torch.nn.Module, name: str):
|
||||
|
||||
print(name)
|
||||
|
||||
input_shape = (1, 3, 4, 4)
|
||||
input = torch.randn(input_shape)
|
||||
print_cpp_vector(torch.flatten(input), "input")
|
||||
|
||||
m.eval()
|
||||
inception_out = m(input)
|
||||
linear_in = torch.flatten(inception_out).size(0)
|
||||
|
||||
inception_block = InceptionBlockModel(m, linear_in)
|
||||
inception_block.apply(init_weights)
|
||||
|
||||
export_model_weights(inception_block, f"resources/{name}.bin")
|
||||
|
||||
inception_block.eval()
|
||||
output = inception_block(input)
|
||||
print_cpp_vector(torch.flatten(output), "expected")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
m = InceptionA(3, 6)
|
||||
generate_module_test_data(m, "inception_a")
|
||||
|
||||
m = InceptionB(3)
|
||||
generate_module_test_data(m, "inception_b")
|
||||
|
||||
m = InceptionC(3, 64)
|
||||
generate_module_test_data(m, "inception_c")
|
||||
|
||||
m = InceptionD(3)
|
||||
generate_module_test_data(m, "inception_d")
|
||||
|
||||
m = InceptionE(3)
|
||||
generate_module_test_data(m, "inception_e")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user