Rework inception block tests

This commit is contained in:
2024-06-05 21:58:33 +02:00
parent a54ffa8b20
commit 1136ca452f
12 changed files with 416 additions and 104738 deletions

View File

@@ -0,0 +1,83 @@
import sys
import torch
from torchvision.models.inception import (
InceptionA,
InceptionB,
InceptionC,
InceptionD,
InceptionE
)
sys.path.append("../../../tools")
from utils import print_cpp_vector, export_model_weights
torch.manual_seed(0)
output_size = 50
class InceptionBlockModel(torch.nn.Module):
def __init__(self, inception_block: torch.nn.Module, linear_in: int, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.inception_block = inception_block
self.fc = torch.nn.Linear(linear_in, output_size)
def forward(self, x):
x = self.inception_block(x)
x = torch.flatten(x)
x = self.fc(x)
return x
@torch.no_grad()
def init_weights(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.uniform_(m.weight, -1, 1)
elif isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.Linear):
torch.nn.init.uniform_(m.weight, -1)
torch.nn.init.uniform_(m.bias, 1)
@torch.no_grad()
def generate_module_test_data(m: torch.nn.Module, name: str):
print(name)
input_shape = (1, 3, 4, 4)
input = torch.randn(input_shape)
print_cpp_vector(torch.flatten(input), "input")
m.eval()
inception_out = m(input)
linear_in = torch.flatten(inception_out).size(0)
inception_block = InceptionBlockModel(m, linear_in)
inception_block.apply(init_weights)
export_model_weights(inception_block, f"resources/{name}.bin")
inception_block.eval()
output = inception_block(input)
print_cpp_vector(torch.flatten(output), "expected")
print()
if __name__ == "__main__":
m = InceptionA(3, 6)
generate_module_test_data(m, "inception_a")
m = InceptionB(3)
generate_module_test_data(m, "inception_b")
m = InceptionC(3, 64)
generate_module_test_data(m, "inception_c")
m = InceptionD(3)
generate_module_test_data(m, "inception_d")
m = InceptionE(3)
generate_module_test_data(m, "inception_e")