mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-07 02:04:26 +00:00
Fix inception v3 tests
This commit is contained in:
@@ -2,6 +2,7 @@ import sys
|
||||
|
||||
import torch
|
||||
from torchvision.models.inception import (
|
||||
BasicConv2d,
|
||||
InceptionA,
|
||||
InceptionB,
|
||||
InceptionC,
|
||||
@@ -26,6 +27,7 @@ class InceptionBlockModel(torch.nn.Module):
|
||||
x = self.inception_block(x)
|
||||
x = torch.flatten(x)
|
||||
x = self.fc(x)
|
||||
# x = torch.nn.functional.tanh(x)
|
||||
return x
|
||||
|
||||
|
||||
@@ -37,6 +39,11 @@ def init_weights(m: torch.nn.Module):
|
||||
torch.nn.init.uniform_(m.weight, -1)
|
||||
torch.nn.init.uniform_(m.bias, 1)
|
||||
|
||||
if isinstance(m, torch.nn.BatchNorm2d):
|
||||
# Initialize running_mean and running_var
|
||||
m.running_mean.uniform_(-1, 1)
|
||||
m.running_var.uniform_(0, 1) # Variance should be positive
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_module_test_data(m: torch.nn.Module, name: str):
|
||||
@@ -64,17 +71,20 @@ def generate_module_test_data(m: torch.nn.Module, name: str):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
m = InceptionA(3, 6)
|
||||
generate_module_test_data(m, "inception_a")
|
||||
# m = BasicConv2d(3, 32, kernel_size=3, stride=2, padding=0)
|
||||
# generate_module_test_data(m, "basic_conv2d")
|
||||
|
||||
m = InceptionB(3)
|
||||
generate_module_test_data(m, "inception_b")
|
||||
# m = InceptionA(3, 6)
|
||||
# generate_module_test_data(m, "inception_a")
|
||||
|
||||
m = InceptionC(3, 64)
|
||||
generate_module_test_data(m, "inception_c")
|
||||
# m = InceptionB(3)
|
||||
# generate_module_test_data(m, "inception_b")
|
||||
|
||||
m = InceptionD(3)
|
||||
generate_module_test_data(m, "inception_d")
|
||||
# m = InceptionC(3, 64)
|
||||
# generate_module_test_data(m, "inception_c")
|
||||
|
||||
# m = InceptionD(3)
|
||||
# generate_module_test_data(m, "inception_d")
|
||||
|
||||
m = InceptionE(3)
|
||||
generate_module_test_data(m, "inception_e")
|
||||
|
||||
Reference in New Issue
Block a user