From 09e64f8037a53f042c3ef0abbb20e52a4ef14e5c Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 30 May 2024 19:31:28 +0200 Subject: [PATCH] Reformat python files --- examples/alexnet/alexnet.py | 25 ------------------------- examples/inception_v3/inception_v3.py | 15 +++++++-------- 2 files changed, 7 insertions(+), 33 deletions(-) diff --git a/examples/alexnet/alexnet.py b/examples/alexnet/alexnet.py index 63bb580..bcd0a36 100644 --- a/examples/alexnet/alexnet.py +++ b/examples/alexnet/alexnet.py @@ -1,35 +1,10 @@ import sys -import torch import torchvision -from PIL import Image -from torchvision import transforms sys.path.append('../../tools') # Ugly hack from utils import export_model_weights, print_model_parameters - -def predict(model, image_path): - input_image = Image.open(image_path) - preprocess = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(227), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ]) - input_tensor = preprocess(input_image) - input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model - - # move the input and model to GPU for speed if available - if torch.cuda.is_available(): - input_batch = input_batch.to('cuda') - model.to('cuda') - - with torch.no_grad(): - output = model(input_batch) - print(torch.argmax(output)) - - if __name__ == "__main__": alexnet = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.DEFAULT) print_model_parameters(alexnet) # print layer names and number of parameters diff --git a/examples/inception_v3/inception_v3.py b/examples/inception_v3/inception_v3.py index f052fd1..a5eadad 100644 --- a/examples/inception_v3/inception_v3.py +++ b/examples/inception_v3/inception_v3.py @@ -1,16 +1,15 @@ import torchvision import sys -sys.path.append('../../tools') # Ugly hack -from utils import export_model_weights, print_model_parameters, predict +sys.path.append("../../tools") # Ugly hack +from utils import export_model_weights, print_model_parameters if __name__ == "__main__": - inception = torchvision.models.inception_v3(weights=torchvision.models.Inception_V3_Weights.DEFAULT) + inception = torchvision.models.inception_v3( + weights=torchvision.models.Inception_V3_Weights.DEFAULT + ) inception.eval() - # print_model_parameters(inception) # print layer names and number of parameters - - # export_model_weights(inception, 'inception_v3_weights.bin') - - print(predict(inception, "./margot.jpg")) + print_model_parameters(inception) # print layer names and number of parameters + export_model_weights(inception, "inception_v3_weights.bin")