mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-05 17:34:21 +00:00
Update alexnet preprocessing
This commit is contained in:
@@ -3,11 +3,17 @@ import sys
|
||||
import torchvision
|
||||
|
||||
sys.path.append('../../tools') # Ugly hack
|
||||
from utils import export_model_weights, print_model_parameters
|
||||
from utils import export_model_weights, print_model_parameters, predict
|
||||
|
||||
if __name__ == "__main__":
|
||||
alexnet = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.DEFAULT)
|
||||
print_model_parameters(alexnet) # print layer names and number of parameters
|
||||
export_model_weights(alexnet, 'alexnet_weights.bin')
|
||||
# predict(alexnet, 'cat.jpg')
|
||||
|
||||
weights = torchvision.models.AlexNet_Weights.DEFAULT
|
||||
alexnet = torchvision.models.alexnet(weights=weights)
|
||||
|
||||
# print_model_parameters(alexnet) # print layer names and number of parameters
|
||||
export_model_weights(alexnet, 'alexnet_weights.bin')
|
||||
|
||||
# class_labels = weights.meta["categories"]
|
||||
# prediction = predict(alexnet, "margot.jpg")
|
||||
# print(prediction, class_labels[prediction])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user