From 9a9b034ce52ba1dcd8488196959a7d9083886b60 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Mon, 22 Apr 2024 21:58:32 +0200 Subject: [PATCH] Add torch predict function --- examples/alexnet/alexnet.py | 38 +++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/examples/alexnet/alexnet.py b/examples/alexnet/alexnet.py index df64ebb..de401d2 100644 --- a/examples/alexnet/alexnet.py +++ b/examples/alexnet/alexnet.py @@ -1,20 +1,38 @@ -import torchvision -import torch import sys -from torchsummary import summary +import torch +import torchvision +from PIL import Image +from torchvision import transforms sys.path.append('../../tools') # Ugly hack from utils import export_model_weights, print_model_parameters + +def predict(model, image_path): + input_image = Image.open(image_path) + preprocess = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(227), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + input_tensor = preprocess(input_image) + input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model + + # move the input and model to GPU for speed if available + if torch.cuda.is_available(): + input_batch = input_batch.to('cuda') + model.to('cuda') + + with torch.no_grad(): + output = model(input_batch) + print(torch.argmax(output)) + + if __name__ == "__main__": - alexnet = torchvision.models.alexnet(pretrained=True) + alexnet = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.DEFAULT) print_model_parameters(alexnet) # print layer names and number of parameters export_model_weights(alexnet, 'alexnet_weights.bin') - print() - - if torch.cuda.is_available(): - alexnet.cuda() - - summary(alexnet, (3, 227, 227)) + # predict('cat.jpg')