diff --git a/examples/inception_v3/inception_v3.py b/examples/inception_v3/inception_v3.py index 0570af7..869ee04 100644 --- a/examples/inception_v3/inception_v3.py +++ b/examples/inception_v3/inception_v3.py @@ -1,17 +1,14 @@ -import torch import torchvision import sys -from torchsummary import summary - -inception = torchvision.models.inception_v3(weights=torchvision.models.Inception_V3_Weights.DEFAULT) -inception.eval() - sys.path.append('../../tools') # Ugly hack from utils import export_model_weights, print_model_parameters -print_model_parameters(inception) # print layer names and number of parameters -inception.cuda() +if __name__ == "__main__": + inception = torchvision.models.inception_v3(weights=torchvision.models.Inception_V3_Weights.DEFAULT) + inception.eval() -summary(inception, (3, 299, 299)) \ No newline at end of file + print_model_parameters(inception) # print layer names and number of parameters + + export_model_weights(inception, 'inception_v3_weights.bin') \ No newline at end of file