Fix inception v3 tests

This commit is contained in:
2024-08-31 23:08:29 +02:00
parent bc9bff10cd
commit c8557fc0e4
9 changed files with 150 additions and 294 deletions

View File

@@ -2,6 +2,7 @@ import sys
import torch
from torchvision.models.inception import (
BasicConv2d,
InceptionA,
InceptionB,
InceptionC,
@@ -26,6 +27,7 @@ class InceptionBlockModel(torch.nn.Module):
x = self.inception_block(x)
x = torch.flatten(x)
x = self.fc(x)
# x = torch.nn.functional.tanh(x)
return x
@@ -37,6 +39,11 @@ def init_weights(m: torch.nn.Module):
torch.nn.init.uniform_(m.weight, -1)
torch.nn.init.uniform_(m.bias, 1)
if isinstance(m, torch.nn.BatchNorm2d):
# Initialize running_mean and running_var
m.running_mean.uniform_(-1, 1)
m.running_var.uniform_(0, 1) # Variance should be positive
@torch.no_grad()
def generate_module_test_data(m: torch.nn.Module, name: str):
@@ -64,17 +71,20 @@ def generate_module_test_data(m: torch.nn.Module, name: str):
if __name__ == "__main__":
m = InceptionA(3, 6)
generate_module_test_data(m, "inception_a")
# m = BasicConv2d(3, 32, kernel_size=3, stride=2, padding=0)
# generate_module_test_data(m, "basic_conv2d")
m = InceptionB(3)
generate_module_test_data(m, "inception_b")
# m = InceptionA(3, 6)
# generate_module_test_data(m, "inception_a")
m = InceptionC(3, 64)
generate_module_test_data(m, "inception_c")
# m = InceptionB(3)
# generate_module_test_data(m, "inception_b")
m = InceptionD(3)
generate_module_test_data(m, "inception_d")
# m = InceptionC(3, 64)
# generate_module_test_data(m, "inception_c")
# m = InceptionD(3)
# generate_module_test_data(m, "inception_d")
m = InceptionE(3)
generate_module_test_data(m, "inception_e")