Files
CUDANet/examples/alexnet/alexnet.py
2024-05-30 19:31:28 +02:00

14 lines
435 B
Python

"""Export pretrained torchvision AlexNet weights for the CUDANet AlexNet example.

When run as a script, this prints AlexNet's layer names and parameter counts,
then writes the weights to ``alexnet_weights.bin``. That path is relative to
the current working directory. Both steps use the shared ``utils`` helpers
that live under the repository's ``tools/`` directory.
"""

import sys
from pathlib import Path

import torchvision

# HACK: ``tools/`` is not an installable package, so its ``utils`` module is
# made importable by adding the directory to sys.path. The directory is
# resolved relative to this file (examples/alexnet/ -> ../../tools) rather
# than the current working directory. The import then works regardless of
# where the script is launched from.
_TOOLS_DIR = Path(__file__).resolve().parents[2] / "tools"
sys.path.append(str(_TOOLS_DIR))

from utils import export_model_weights, print_model_parameters  # noqa: E402


def main() -> None:
    """Load AlexNet with torchvision's default weights, report and export them."""
    alexnet = torchvision.models.alexnet(
        weights=torchvision.models.AlexNet_Weights.DEFAULT
    )
    # Print layer names and number of parameters.
    print_model_parameters(alexnet)
    export_model_weights(alexnet, "alexnet_weights.bin")
    # TODO: add an inference sanity check, e.g. predict(alexnet, "cat.jpg");
    # no ``predict`` helper is imported or defined here yet.


if __name__ == "__main__":
    main()