Update inception v3 readme

This commit is contained in:
2024-09-04 21:32:05 +02:00
parent c8bc6f7a39
commit 7b8c4bd811
3 changed files with 56 additions and 10 deletions

View File

@@ -35,8 +35,6 @@ def export_model_weights(model: torch.nn.Module, filename):
tensor_data += tensor_bytes
# print(model.named_buffers)
# Add buffers (for running_mean and running_var)
for name, buf in model.named_buffers():
if "running_mean" not in name and "running_var" not in name:
@@ -76,9 +74,7 @@ def predict(model, image_path, resize=299, crop=299, preprocess=None):
)
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(
0
) # create a mini-batch as expected by the model
input_batch = input_tensor.unsqueeze(0)
# move the input and model to GPU for speed if available
if torch.cuda.is_available():
@@ -86,5 +82,5 @@ def predict(model, image_path, resize=299, crop=299, preprocess=None):
model.to("cuda")
with torch.no_grad():
output = model(input_batch)
return torch.argmax(output)
output = model(input_batch)
return torch.argmax(output).item()