Update alexnet preprocessing

This commit is contained in:
2024-09-04 21:41:26 +02:00
parent 7b8c4bd811
commit e7ec6c91f8
2 changed files with 31 additions and 10 deletions

View File

@@ -3,11 +3,17 @@ import sys
import torchvision
sys.path.append('../../tools') # Ugly hack
from utils import export_model_weights, print_model_parameters
from utils import export_model_weights, print_model_parameters, predict
if __name__ == "__main__":
alexnet = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.DEFAULT)
print_model_parameters(alexnet) # print layer names and number of parameters
export_model_weights(alexnet, 'alexnet_weights.bin')
# predict(alexnet, 'cat.jpg')
weights = torchvision.models.AlexNet_Weights.DEFAULT
alexnet = torchvision.models.alexnet(weights=weights)
# print_model_parameters(alexnet) # print layer names and number of parameters
export_model_weights(alexnet, 'alexnet_weights.bin')
# class_labels = weights.meta["categories"]
# prediction = predict(alexnet, "margot.jpg")
# print(prediction, class_labels[prediction])