mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 09:44:28 +00:00
94 lines
2.4 KiB
Python
94 lines
2.4 KiB
Python
import sys
|
|
|
|
import torch
|
|
from torchvision.models.inception import (
|
|
BasicConv2d,
|
|
InceptionA,
|
|
InceptionB,
|
|
InceptionC,
|
|
InceptionD,
|
|
InceptionE
|
|
)
|
|
|
|
sys.path.append("../../../tools")
|
|
from utils import print_cpp_vector, export_model_weights
|
|
|
|
# Seed the global RNG so the random input and the randomly initialised
# weights produced by this script are reproducible from run to run.
torch.manual_seed(0)

# Number of output features of the linear layer that InceptionBlockModel
# places after the wrapped block.
output_size = 50
|
|
|
|
class InceptionBlockModel(torch.nn.Module):
    """Wrap a module and follow it with a fully connected layer.

    The wrapped module's output is flattened into one 1-D vector and passed
    through ``fc``, a linear layer with ``output_size`` output features.
    """

    def __init__(
        self,
        inception_block: torch.nn.Module,
        linear_in: int,
        *args,
        **kwargs,
    ) -> None:
        """Store the wrapped module and build the linear layer.

        Args:
            inception_block: Module applied to the input first.
            linear_in: Number of input features of the linear layer, i.e. the
                element count of the flattened ``inception_block`` output.
            *args: Forwarded unchanged to ``torch.nn.Module.__init__``.
            **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.inception_block = inception_block
        # ``output_size`` is the module-level constant defined in this file.
        self.fc = torch.nn.Linear(linear_in, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped module, flatten the result and apply ``fc``.

        ``torch.flatten`` is called without ``start_dim``, so every dimension
        of the wrapped module's output is collapsed into a single one.
        """
        features = self.inception_block(x)
        flat = torch.flatten(features)
        out = self.fc(flat)
        # Optional activation, currently disabled:
        # out = torch.nn.functional.tanh(out)
        return out
|
|
|
|
|
|
@torch.no_grad()
def init_weights(m: torch.nn.Module) -> None:
    """Randomly re-initialise one module's parameters in place.

    Written to be passed to ``torch.nn.Module.apply``, which calls it once per
    submodule. Modules that are not ``Conv2d``, ``BatchNorm2d`` or ``Linear``
    are left untouched, as is the bias of a ``Conv2d``.

    Parameters or buffers that are ``None`` (e.g. ``Linear(bias=False)``,
    ``BatchNorm2d(affine=False)`` or ``track_running_stats=False``) are
    skipped instead of raising.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, torch.nn.Conv2d):
        torch.nn.init.uniform_(m.weight, -1, 1)
    elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.Linear)):
        if m.weight is not None:
            # Only the lower bound is given; the upper bound keeps its
            # default of 1.0, so this samples from U(-1, 1).
            torch.nn.init.uniform_(m.weight, -1)
        if m.bias is not None:
            # NOTE(review): lower bound 1 with the default upper bound 1.0 is
            # the degenerate range [1, 1], so every bias becomes 1.0. Kept
            # as-is so generated data does not change; confirm whether
            # (-1, 1) was intended.
            torch.nn.init.uniform_(m.bias, 1)

    if isinstance(m, torch.nn.BatchNorm2d):
        # Randomise the running statistics as well, so that eval-mode
        # normalisation does not use the default mean 0 / variance 1.
        if m.running_mean is not None:
            m.running_mean.uniform_(-1, 1)
        if m.running_var is not None:
            m.running_var.uniform_(0, 1)  # Variance should be positive
|
|
|
|
|
|
@torch.no_grad()
def generate_module_test_data(m: torch.nn.Module, name: str) -> None:
    """Print a random input and the expected output for ``m``, and export weights.

    Steps, in order:
      1. Print ``name`` and draw a random input of shape ``(1, 3, 4, 4)``.
      2. Run ``m`` once in eval mode to learn the flattened output size.
      3. Wrap ``m`` in ``InceptionBlockModel`` and re-initialise every
         submodule with ``init_weights``.
      4. Export the wrapped model to ``resources/{name}.bin``.
      5. Run the wrapped model in eval mode and print its flattened output.

    Args:
        m: Module to generate data for. It is switched to eval mode and,
            because it becomes a submodule of the wrapper, its weights are
            overwritten by ``init_weights``.
        name: Label printed first and used as the exported file's base name.
    """
    print(name)

    input_shape = (1, 3, 4, 4)
    # Named ``sample`` rather than ``input`` to avoid shadowing the builtin.
    sample = torch.randn(input_shape)
    print_cpp_vector(torch.flatten(sample), "input")

    # Probe pass: only the element count of the output is used, so it does
    # not matter that ``m`` still has its constructor-initialised weights.
    m.eval()
    block_out = m(sample)
    linear_in = torch.flatten(block_out).size(0)

    model = InceptionBlockModel(m, linear_in)
    model.apply(init_weights)

    export_model_weights(model, f"resources/{name}.bin")

    # The wrapper is freshly constructed, so switch it to eval mode before the
    # reference forward pass.
    model.eval()
    expected = model(sample)
    print_cpp_vector(torch.flatten(expected), "expected")

    print()
|
|
|
|
|
|
if __name__ == "__main__":
    # Data is generated for one block per run. To produce data for another
    # block, comment out the active pair at the bottom and enable one of the
    # pairs below.

    # m = BasicConv2d(3, 32, kernel_size=3, stride=2, padding=0)
    # generate_module_test_data(m, "basic_conv2d")

    # m = InceptionA(3, 6)
    # generate_module_test_data(m, "inception_a")

    # m = InceptionB(3)
    # generate_module_test_data(m, "inception_b")

    # m = InceptionC(3, 64)
    # generate_module_test_data(m, "inception_c")

    # m = InceptionD(3)
    # generate_module_test_data(m, "inception_d")

    m = InceptionE(3)
    generate_module_test_data(m, "inception_e")
|
|
|
|
|
|
|