Exporting a Model from PyTorch to ONNX tutorial - Docathon #2935

Merged: 10 commits, Jun 18, 2024
Binary file added _static/img/cat_resized.jpg
34 changes: 32 additions & 2 deletions advanced_source/super_resolution_with_onnxruntime.py
@@ -107,7 +107,7 @@ def _initialize_weights(self):

# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
-batch_size = 1    # just a random number
+batch_size = 64   # just a random number

# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
@@ -218,6 +218,32 @@ def to_numpy(tensor):
# ONNX exporter, so please contact us in that case.
#

######################################################################
# Timing Comparison Between Models
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

######################################################################
# Since ONNX Runtime optimizes models for inference, running the same
# data through the exported ONNX model rather than the native PyTorch model
# should yield a speedup of up to 2x. The improvement is more pronounced
# at higher batch sizes.


import time

x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

# Time a single forward pass of the native PyTorch model
start = time.time()
torch_out = torch_model(x)
end = time.time()
print(f"Inference of PyTorch model used {end - start} seconds")

# Time a single run of the exported model in ONNX Runtime
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
start = time.time()
ort_outs = ort_session.run(None, ort_inputs)
end = time.time()
print(f"Inference of ONNX model used {end - start} seconds")


######################################################################
# Running the model on an image using ONNX Runtime
@@ -301,8 +327,12 @@ def to_numpy(tensor):
# Save the image; we will compare this with the output image from the mobile device
final_img.save("./_static/img/cat_superres_with_ort.jpg")

# Save the resized original image (without super-resolution) for comparison.
# Note: PIL's ``Image.size`` is (width, height), while ``transforms.Resize``
# expects (height, width).
img = transforms.Resize([img_out_y.size[1], img_out_y.size[0]])(img)
img.save("./_static/img/cat_resized.jpg")

######################################################################
# .. figure:: /_static/img/cat_resized.jpg
# .. figure:: /_static/img/cat_superres_with_ort.jpg
# :alt: output\_cat
#
@@ -313,7 +343,7 @@ def to_numpy(tensor):
# ONNX Runtime can also be deployed to the cloud for model inferencing
# using Azure Machine Learning Services. More information `here <https://docs.microsoft.com/en-us/azure/machine-learning/service/concept-onnx>`__.
#
-# More information about ONNX Runtime's performance `here <https://github.com/microsoft/onnxruntime#high-performance>`__.
+# More information about ONNX Runtime's performance `here <https://onnxruntime.ai/docs/performance>`__.
#
#
# For more information about ONNX Runtime, see `here <https://github.com/microsoft/onnxruntime>`__.