mirror of
https://github.com/lordmathis/CUDANet.git
synced 2025-11-06 01:34:22 +00:00
Update inception v3 readme
This commit is contained in:
@@ -35,8 +35,6 @@ def export_model_weights(model: torch.nn.Module, filename):
|
||||
|
||||
tensor_data += tensor_bytes
|
||||
|
||||
# print(model.named_buffers)
|
||||
|
||||
# Add buffers (for running_mean and running_var)
|
||||
for name, buf in model.named_buffers():
|
||||
if "running_mean" not in name and "running_var" not in name:
|
||||
@@ -76,9 +74,7 @@ def predict(model, image_path, resize=299, crop=299, preprocess=None):
|
||||
)
|
||||
|
||||
input_tensor = preprocess(input_image)
|
||||
input_batch = input_tensor.unsqueeze(
|
||||
0
|
||||
) # create a mini-batch as expected by the model
|
||||
input_batch = input_tensor.unsqueeze(0)
|
||||
|
||||
# move the input and model to GPU for speed if available
|
||||
if torch.cuda.is_available():
|
||||
@@ -86,5 +82,5 @@ def predict(model, image_path, resize=299, crop=299, preprocess=None):
|
||||
model.to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_batch)
|
||||
return torch.argmax(output)
|
||||
output = model(input_batch)
|
||||
return torch.argmax(output).item()
|
||||
|
||||
Reference in New Issue
Block a user