mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Update inception py
This commit is contained in:
@@ -1,17 +1,14 @@
|
|||||||
import torch
|
|
||||||
import torchvision
|
import torchvision
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from torchsummary import summary
|
|
||||||
|
|
||||||
inception = torchvision.models.inception_v3(weights=torchvision.models.Inception_V3_Weights.DEFAULT)
|
|
||||||
inception.eval()
|
|
||||||
|
|
||||||
sys.path.append('../../tools') # Ugly hack
|
sys.path.append('../../tools') # Ugly hack
|
||||||
from utils import export_model_weights, print_model_parameters
|
from utils import export_model_weights, print_model_parameters
|
||||||
|
|
||||||
print_model_parameters(inception) # print layer names and number of parameters
|
|
||||||
|
|
||||||
inception.cuda()
|
if __name__ == "__main__":
|
||||||
|
inception = torchvision.models.inception_v3(weights=torchvision.models.Inception_V3_Weights.DEFAULT)
|
||||||
|
inception.eval()
|
||||||
|
|
||||||
summary(inception, (3, 299, 299))
|
print_model_parameters(inception) # print layer names and number of parameters
|
||||||
|
|
||||||
|
export_model_weights(inception, 'inception_v3_weights.bin')
|
||||||
Reference in New Issue
Block a user