@@ -107,7 +107,7 @@ def _initialize_weights(self):
 
 # Load pretrained model weights
 model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
-batch_size = 1    # just a random number
+batch_size = 64   # a larger batch makes the ONNX Runtime speedup easier to see
 
 # Initialize model with the pretrained weights
 map_location = lambda storage, loc: storage
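+# (the lambda above maps every storage to the CPU during deserialization,
+# so the pretrained weights load even on a machine without a GPU)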
@@ -218,6 +218,32 @@ def to_numpy(tensor):
 # ONNX exporter, so please contact us in that case.
 #
 
+######################################################################
+# Timing Comparison Between Models
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+
+######################################################################
+# Because ONNX Runtime optimizes the graph for inference, running the same
+# data through the exported ONNX model instead of the native PyTorch model
+# can yield a speedup of up to 2x. The improvement is more pronounced at
+# higher batch sizes.
+
+
+import time
+
+x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
+
+# Time one forward pass of the native PyTorch model
+start = time.time()
+torch_out = torch_model(x)
+end = time.time()
+print(f"Inference of PyTorch model used {end - start} seconds")
+
+# Time the same batch through the ONNX Runtime session
+ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
+start = time.time()
+ort_outs = ort_session.run(None, ort_inputs)
+end = time.time()
+print(f"Inference of ONNX model used {end - start} seconds")
+
 
 ######################################################################
 # Running the model on an image using ONNX Runtime
@@ -301,10 +327,20 @@ def to_numpy(tensor):
 # Save the image; we will compare this with the output image from the mobile device
 final_img.save("./_static/img/cat_superres_with_ort.jpg")
 
+# Save the resized original image (without super-resolution) for comparison
+img = transforms.Resize([img_out_y.size[0], img_out_y.size[1]])(img)
+img.save("./_static/img/cat_resized.jpg")
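+
+# For a quick visual check, the two results can also be pasted onto one
+# canvas, roughly like this (a small sketch; ``Image`` is the PIL class
+# imported earlier in the tutorial, and the output filename is arbitrary):
+side_by_side = Image.new("RGB", (img.size[0] + final_img.size[0], final_img.size[1]))
+side_by_side.paste(img, (0, 0))
+side_by_side.paste(final_img, (img.size[0], 0))
+side_by_side.save("./_static/img/cat_side_by_side.jpg")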
 
 ######################################################################
+# Here is the comparison between the two images:
+#
+# .. figure:: /_static/img/cat_resized.jpg
+#
+#    Low-resolution image
+#
 # .. figure:: /_static/img/cat_superres_with_ort.jpg
-#    :alt: output\_cat
+#
+#    Image after super-resolution
 #
 #
 # ONNX Runtime being a cross platform engine, you can run it across
@@ -313,7 +349,7 @@ def to_numpy(tensor):
 # ONNX Runtime can also be deployed to the cloud for model inferencing
 # using Azure Machine Learning Services. More information `here <https://docs.microsoft.com/en-us/azure/machine-learning/service/concept-onnx>`__.
 #
-# More information about ONNX Runtime's performance `here <https://github.com/microsoft/onnxruntime#high-performance>`__.
+# More information about ONNX Runtime's performance `here <https://onnxruntime.ai/docs/performance>`__.
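+#
+# For instance, the session's graph-optimization level can be set
+# explicitly when it is created (a small sketch; ``ORT_ENABLE_ALL``
+# happens to be the default level, and ``super_resolution.onnx`` is the
+# file exported earlier in the tutorial)::
+#
+#    sess_options = onnxruntime.SessionOptions()
+#    sess_options.graph_optimization_level = \
+#        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+#    ort_session = onnxruntime.InferenceSession("super_resolution.onnx",
+#                                               sess_options)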
 #
 #
 # For more information about ONNX Runtime, see `here <https://github.com/microsoft/onnxruntime>`__.