Commit 42fe906

Merge branch 'main' into improve_dqn_tutorial

2 parents 39b20e7 + be898cb

File tree

2 files changed: +39 -3 lines


Diff for: _static/img/cat_resized.jpg

Binary file (39.2 KB)
Diff for: advanced_source/super_resolution_with_onnxruntime.py

+39 -3
@@ -107,7 +107,7 @@ def _initialize_weights(self):
 
 # Load pretrained model weights
 model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
-batch_size = 1    # just a random number
+batch_size = 64    # just a random number
 
 # Initialize model with the pretrained weights
 map_location = lambda storage, loc: storage
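Note (not part of this diff): batch_size feeds the dummy tensor that the tutorial later passes to torch.onnx.export. A minimal sketch of that export step, assuming the tutorial's torch_model, with the batch dimension marked dynamic so the exported model accepts any batch size:

import torch

# Dummy input whose first dimension is the (arbitrary) batch size
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

torch.onnx.export(torch_model,              # model being exported (from the tutorial)
                  x,                        # example input used to trace the graph
                  "super_resolution.onnx",  # file to write
                  export_params=True,       # store trained weights in the model file
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input':  {0: 'batch_size'},   # variable batch dim
                                'output': {0: 'batch_size'}})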
@@ -218,6 +218,32 @@ def to_numpy(tensor):
 # ONNX exporter, so please contact us in that case.
 #
 
+######################################################################
+# Timing Comparison Between Models
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+
+######################################################################
+# Because ONNX Runtime is optimized for inference speed, running the same
+# data through the ONNX model instead of the native PyTorch model should
+# give a speedup of up to 2x; the improvement grows with batch size.
+
+
+import time
+
+x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
+
+start = time.time()
+torch_out = torch_model(x)
+end = time.time()
+print(f"Inference of PyTorch model used {end - start} seconds")
+
+ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
+start = time.time()
+ort_outs = ort_session.run(None, ort_inputs)
+end = time.time()
+print(f"Inference of ONNX model used {end - start} seconds")
+
 
 ######################################################################
 # Running the model on an image using ONNX Runtime
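Note (not part of this diff): a single timed call is noisy, and calling torch_model(x) on a tensor with requires_grad=True also records an autograd graph, which inflates the PyTorch figure. A sketch of a steadier measurement, assuming the tutorial's torch_model, ort_session, and to_numpy:

import time
import torch

def benchmark(fn, warmup=3, runs=10):
    # Warm-up calls let allocators and lazy initialization settle
    # before we start measuring.
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return sum(times) / len(times)   # mean seconds per run

x = torch.randn(batch_size, 1, 224, 224)   # no requires_grad needed for inference
with torch.no_grad():                      # skip autograd bookkeeping for a fair timing
    print(f"PyTorch: {benchmark(lambda: torch_model(x)):.4f} s/run")

ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
print(f"ONNX Runtime: {benchmark(lambda: ort_session.run(None, ort_inputs)):.4f} s/run")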
@@ -301,10 +327,20 @@ def to_numpy(tensor):
 # Save the image, we will compare this with the output image from mobile device
 final_img.save("./_static/img/cat_superres_with_ort.jpg")
 
+# Save resized original image (without super-resolution)
+img = transforms.Resize([img_out_y.size[0], img_out_y.size[1]])(img)
+img.save("cat_resized.jpg")
 
 ######################################################################
+# Here is the comparison between the two images:
+#
+# .. figure:: /_static/img/cat_resized.jpg
+#
+#    Low-resolution image
+#
 # .. figure:: /_static/img/cat_superres_with_ort.jpg
-#    :alt: output\_cat
+#
+#    Image after super-resolution
 #
 #
 # ONNX Runtime being a cross platform engine, you can run it across
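Note (not part of this diff): PIL's Image.size is (width, height), while torchvision's transforms.Resize expects (height, width) when given a sequence, so the added resize line flips the dimensions for non-square images (harmless in this tutorial, where the image is square). A quick illustration:

from PIL import Image
from torchvision import transforms

im = Image.new("L", (640, 480))   # PIL reports size as (width, height)

# Passing PIL's (width, height) straight to Resize flips the dimensions:
print(transforms.Resize([im.size[0], im.size[1]])(im).size)   # -> (480, 640)

# Swapping the components preserves the intended shape:
print(transforms.Resize([im.size[1], im.size[0]])(im).size)   # -> (640, 480)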
@@ -313,7 +349,7 @@ def to_numpy(tensor):
 # ONNX Runtime can also be deployed to the cloud for model inferencing
 # using Azure Machine Learning Services. More information `here <https://docs.microsoft.com/en-us/azure/machine-learning/service/concept-onnx>`__.
 #
-# More information about ONNX Runtime's performance `here <https://github.com/microsoft/onnxruntime#high-performance>`__.
+# More information about ONNX Runtime's performance `here <https://onnxruntime.ai/docs/performance>`__.
 #
 #
 # For more information about ONNX Runtime `here <https://github.com/microsoft/onnxruntime>`__.
