[ONNX] Update API to torch.onnx.export(..., dynamo=True) #3223
Merged
Changes from 24 commits
Commits (33)
54a48a5  Fix torchrl scripts for PT 2.6 TorchRL>=0.6 (#3199)  (vmoens)
d9660ce  update api  (titaiwangms)
b476c7e  add torch.cond  (titaiwangms)
aa3fc9e  add torch.compiler.set_stance tutorial (#3225)  (williamwen42)
d3fec71  Revert "add torch.compiler.set_stance tutorial (#3225)" (#3231)  (williamwen42)
56c0006  update registry  (titaiwangms)
7c77db6  fix  (titaiwangms)
c29c22b  address formatting  (titaiwangms)
6bca4e4  reformatting  (titaiwangms)
6e416d1  words  (titaiwangms)
846cf83  removed printout and algin titile format  (titaiwangms)
1a2cc7a  refactor intro_onnx and simple_example  (titaiwangms)
7dc050f  revert those in 2.6 but not yet cherry-picks  (titaiwangms)
e104e01  add coding head  (titaiwangms)
f620302  add space  (titaiwangms)
eddc1a1  Merge branch 'main' into titaiwang/dynamo_true_api  (svekars)
c33929a  Merge branch 'main' into titaiwang/dynamo_true_api  (svekars)
159d6b0  Merge branch 'main' into titaiwang/dynamo_true_api  (svekars)
3027ebe  address reviews  (titaiwangms)
8adb7f3  Merge branch 'main' into titaiwang/dynamo_true_api  (titaiwangms)
eb89a8a  Remove dot for consistency  (svekars)
3a1d4a3  fix ci  (titaiwangms)
1cf9731  Merge branch 'main' into titaiwang/dynamo_true_api  (titaiwangms)
d3bb7e7  Merge branch 'main' into titaiwang/dynamo_true_api  (svekars)
b64a434  fix misspelled words  (titaiwangms)
d769747  Merge branch 'main' into titaiwang/dynamo_true_api  (titaiwangms)
0ba53e8  Merge branch 'main' into titaiwang/dynamo_true_api  (titaiwangms)
e15eab6  Merge branch 'main' into titaiwang/dynamo_true_api  (svekars)
c544eeb  Fix author links  (svekars)
9757ccc  Fix author links  (svekars)
16854f2  Merge branch 'main' into titaiwang/dynamo_true_api  (svekars)
f9ac5b5  fix ONNX RUNTIME in comments and opsets names  (titaiwangms)
bfad889  Merge branch 'main' into titaiwang/dynamo_true_api  (titaiwangms)
beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py (184 additions, 0 deletions)
@@ -0,0 +1,184 @@
# -*- coding: utf-8 -*-
"""
`Introduction to ONNX <intro_onnx.html>`_ ||
`Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
`Extending the ONNX exporter operator support <onnx_registry_tutorial.html>`_ ||
**Export a model with control flow to ONNX**

Export a model with control flow to ONNX
========================================

**Author**: `Xavier Dupré <https://github.com/xadupre>`_.

"""


###############################################################################
# Overview
# --------
#
# This tutorial demonstrates how to handle control flow logic while exporting
# a PyTorch model to ONNX. It highlights the challenges of exporting
# conditional statements directly and provides solutions to circumvent them.
#
# Conditional logic cannot be exported into ONNX unless it is refactored
# to use :func:`torch.cond`. Let's start with a simple model
# implementing a test.
#
# What you will learn:
#
# - How to refactor the model to use :func:`torch.cond` for exporting.
# - How to export a model with control flow logic to ONNX.
# - How to optimize the exported model using the ONNX optimizer.
#
# Prerequisites
# ~~~~~~~~~~~~~
#
# * ``torch >= 2.6``


import torch

###############################################################################
# Define the Models
# -----------------
#
# Two models are defined:
#
# ``ForwardWithControlFlowTest``: a model with a forward method containing an
# if-else conditional.
#
# ``ModelWithControlFlowTest``: a model that incorporates ``ForwardWithControlFlowTest``
# as part of a simple multi-layer perceptron (MLP). The models are tested with
# a random input tensor to confirm they execute as expected.

class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlowTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlowTest()


###############################################################################
# Exporting the Model: First Attempt
# ----------------------------------
#
# Exporting this model using ``torch.export.export`` fails because the control
# flow logic in the forward pass creates a graph break that the exporter cannot
# handle. This behavior is expected, as conditional logic not written using
# :func:`torch.cond` is unsupported.
#
# A try-except block is used to capture the expected failure during the export
# process. If the export unexpectedly succeeds, an ``AssertionError`` is raised.

x = torch.randn(3)
model(x)

try:
    torch.export.export(model, (x,), strict=False)
    raise AssertionError("This export should fail unless PyTorch now supports this model.")
except Exception as e:
    print(e)

###############################################################################
# Using torch.onnx.export with JIT Tracing
# ----------------------------------------
#
# When exporting the model using ``torch.onnx.export`` with the ``dynamo=True``
# argument, the exporter defaults to using JIT tracing. This fallback allows
# the model to export, but the resulting ONNX graph may not faithfully represent
# the original model logic due to the limitations of tracing.


onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
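

###############################################################################
# As an illustration (a sketch, not part of the original tutorial) of why the
# tracing fallback can be unfaithful: ``torch.jit.trace`` records only the
# branch taken for the example input, so the traced submodule silently bakes
# that branch in. The input values below are chosen so the two branches
# disagree.

traced_submodule = torch.jit.trace(
    ForwardWithControlFlowTest(), (torch.tensor([1.0, 2.0, 3.0]),)
)  # sum() != 0, so tracing records the ``x * 2`` branch (and warns about it)

x_zero_sum = torch.tensor([1.0, -1.0, 0.0])  # sum() == 0: eager takes the ``-x`` branch
print(ForwardWithControlFlowTest()(x_zero_sum))  # eager: tensor([-1., 1., -0.])
print(traced_submodule(x_zero_sum))  # traced: tensor([2., -2., 0.]), the frozen branch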


###############################################################################
# Suggested Patch: Refactoring with torch.cond
# --------------------------------------------
#
# To make the control flow exportable, the tutorial demonstrates replacing the
# forward method in ``ForwardWithControlFlowTest`` with a refactored version that
# uses :func:`torch.cond`.
#
# Details of the refactoring:
#
# * Two helper functions (``identity2`` and ``neg``) represent the branches of
#   the conditional logic.
# * :func:`torch.cond` is used to specify the condition and the two branches
#   along with the input arguments.
# * The updated forward method is then dynamically assigned to the
#   ``ForwardWithControlFlowTest`` instance within the model. A list of
#   submodules is printed to confirm the replacement.

def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        mod.forward = new_forward
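

###############################################################################
# As a design note (a sketch, not part of the original tutorial): when you
# control the model source, you can avoid patching instances after the fact
# and write the module with :func:`torch.cond` from the start, for example:
#
# .. code-block:: python
#
#     class ControlFlowWithCond(torch.nn.Module):
#         def forward(self, x):
#             def times_two(x):
#                 return x * 2
#
#             def neg(x):
#                 return -x
#
#             return torch.cond(x.sum() > 0, times_two, neg, (x,))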

###############################################################################
# Let's see what the FX graph looks like.

print(torch.export.export(model, (x,), strict=False))


###############################################################################
# Let's export again.

onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)


###############################################################################
# We can optimize the model and get rid of the model-local functions created
# to capture the control flow branches.

onnx_program.optimize()
print(onnx_program.model)
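

###############################################################################
# As a final sanity check (a sketch, not part of the original tutorial, and
# assuming ``onnxruntime`` is installed), we can run the exported program with
# ONNX Runtime and compare it against eager PyTorch. Which branch each input
# takes depends on the sign of the MLP output, but whichever branch eager mode
# selects, ONNX Runtime should agree with it. The file name is arbitrary.

import onnxruntime

onnx_program.save("control_flow_model.onnx")
ort_session = onnxruntime.InferenceSession(
    "control_flow_model.onnx", providers=["CPUExecutionProvider"]
)
input_name = ort_session.get_inputs()[0].name
for candidate in (torch.full((3,), 10.0), torch.full((3,), -10.0)):
    (ort_output,) = ort_session.run(None, {input_name: candidate.numpy()})
    # Whichever branch eager mode takes, the exported graph should match it.
    torch.testing.assert_close(model(candidate), torch.from_numpy(ort_output))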

###############################################################################
# Conclusion
# ----------
#
# This tutorial demonstrates the challenges of exporting models with conditional
# logic to ONNX and presents a practical solution using :func:`torch.cond`.
# While the default exporters may fail or produce imperfect graphs, refactoring
# the model's logic ensures compatibility and generates a faithful ONNX
# representation.
#
# By understanding these techniques, we can overcome common pitfalls when
# working with control flow in PyTorch models and ensure smooth integration
# with ONNX workflows.
#
# Further reading
# ---------------
#
# The list below refers to tutorials that range from basic examples to advanced
# scenarios, not necessarily in the order they are listed.
# Feel free to jump directly to specific topics of your interest or
# sit tight and have fun going through all of them to learn all there is about
# the ONNX exporter.
#
# .. include:: /beginner_source/onnx/onnx_toc.txt
#
# .. toctree::
#    :hidden:

beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
@@ -2,26 +2,27 @@
 """
 `Introduction to ONNX <intro_onnx.html>`_ ||
 **Exporting a PyTorch model to ONNX** ||
-`Extending the ONNX Registry <onnx_registry_tutorial.html>`_
+`Extending the ONNX exporter operator support <onnx_registry_tutorial.html>`_ ||
+`Export a model with control flow to ONNX <export_control_flow_model_to_onnx_tutorial.html>`_

 Export a PyTorch model to ONNX
 ==============================

-**Author**: `Ti-Tai Wang <https://github.com/titaiwangms>`_ and `Xavier Dupré <https://github.com/xadupre>`_
+**Author**: `Ti-Tai Wang <https://github.com/titaiwangms>`_, Justin Chu ([email protected]) and `Thiago Crepaldi <https://github.com/thiagocrepaldi>`_.


 .. note::
-    As of PyTorch 2.1, there are two versions of ONNX Exporter.
+    As of PyTorch 2.5, there are two versions of ONNX Exporter.

-    * ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
-    * ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0
+    * ``torch.onnx.export(..., dynamo=True)`` is the newest (still in beta) exporter using ``torch.export`` and Torch FX to capture the graph. It was released with PyTorch 2.5
+    * ``torch.onnx.export`` uses TorchScript and has been available since PyTorch 1.2.0

 """

 ###############################################################################
 # In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_,
 # we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
 # In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the
-# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter.
+# ONNX format using the ``torch.onnx.export(..., dynamo=True)`` ONNX exporter.
 #
 # While PyTorch is great for iterating on the development of models, the model can be deployed to production
 # using different formats, including `ONNX <https://onnx.ai/>`_ (Open Neural Network Exchange)!
@@ -47,8 +48,7 @@
 #
 # .. code-block:: bash
 #
-#   pip install onnx
-#   pip install onnxscript
+#   pip install --upgrade onnx onnxscript
 #
 # 2. Author a simple image classifier model
 # -----------------------------------------
@@ -62,17 +62,16 @@
 import torch.nn.functional as F


-class MyModel(nn.Module):
-
+class ImageClassifierModel(nn.Module):
     def __init__(self):
-        super(MyModel, self).__init__()
+        super().__init__()
         self.conv1 = nn.Conv2d(1, 6, 5)
         self.conv2 = nn.Conv2d(6, 16, 5)
         self.fc1 = nn.Linear(16 * 5 * 5, 120)
         self.fc2 = nn.Linear(120, 84)
         self.fc3 = nn.Linear(84, 10)

-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
         x = F.max_pool2d(F.relu(self.conv2(x)), 2)
         x = torch.flatten(x, 1)
@@ -81,16 +80,27 @@ def forward(self, x):
         x = self.fc3(x)
         return x


 ######################################################################
 # 3. Export the model to ONNX format
 # ----------------------------------
 #
 # Now that we have our model defined, we need to instantiate it and create a random 32x32 input.
 # Next, we can export the model to ONNX format.

-torch_model = MyModel()
-torch_input = torch.randn(1, 1, 32, 32)
-onnx_program = torch.onnx.dynamo_export(torch_model, torch_input)
+torch_model = ImageClassifierModel()
+# Create example inputs for exporting the model. The inputs should be a tuple of tensors.
+example_inputs = (torch.randn(1, 1, 32, 32),)
+onnx_program = torch.onnx.export(torch_model, example_inputs, dynamo=True)
+
+######################################################################
+# 3.5. (Optional) Optimize the ONNX model
+# ---------------------------------------
+#
+# The ONNX model can be optimized with constant folding and the elimination of redundant nodes.
+# The optimization is done in-place, so the original ONNX model is modified.
+
+onnx_program.optimize()

 ######################################################################
 # As we can see, we didn't need any code change to the model.
@@ -102,13 +112,14 @@ def forward(self, x):
 # Although having the exported model loaded in memory is useful in many applications,
 # we can save it to disk with the following code:

-onnx_program.save("my_image_classifier.onnx")
+onnx_program.save("image_classifier_model.onnx")

 ######################################################################
 # You can load the ONNX file back into memory and check if it is well formed with the following code:

 import onnx
-onnx_model = onnx.load("my_image_classifier.onnx")
+
+onnx_model = onnx.load("image_classifier_model.onnx")
 onnx.checker.check_model(onnx_model)

 ######################################################################
@@ -124,7 +135,7 @@ def forward(self, x):
 # :align: center
 #
 #
-# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser or select it after
+# Once Netron is open, we can drag and drop our ``image_classifier_model.onnx`` file into the browser or select it after
 # clicking the **Open model** button.
 #
 # .. image:: ../../_static/img/onnx/image_classifier_onnx_model_on_netron_web_ui.png
@@ -155,16 +166,15 @@ def forward(self, x):

 import onnxruntime

-onnx_input = [torch_input]
-print(f"Input length: {len(onnx_input)}")
-print(f"Sample input: {onnx_input}")
+onnx_inputs = [tensor.numpy(force=True) for tensor in example_inputs]
+print(f"Input length: {len(onnx_inputs)}")
+print(f"Sample input: {onnx_inputs}")

-ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider'])
+ort_session = onnxruntime.InferenceSession(
+    "./image_classifier_model.onnx", providers=["CPUExecutionProvider"]
+)

-def to_numpy(tensor):
-    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
-
-onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
+onnxruntime_input = {input_arg.name: input_value for input_arg, input_value in zip(ort_session.get_inputs(), onnx_inputs)}

 # onnxruntime returns a list of outputs
 onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0]
@@ -179,7 +189,7 @@ def to_numpy(tensor):
 # For that, we need to execute the PyTorch model with the same input and compare the results with ONNX Runtime's.
 # Before comparing the results, we need to convert the PyTorch's output to match ONNX's format.

-torch_outputs = torch_model(torch_input)
+torch_outputs = torch_model(*example_inputs)

 assert len(torch_outputs) == len(onnxruntime_outputs)
 for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
@@ -209,4 +219,4 @@ def to_numpy(tensor):
 #
 # .. toctree::
 #    :hidden:
-# 
+#
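
A hedged sketch of where the last comparison hunk is heading (the loop body
falls outside the diff context above, so the ``assert_close`` call below is an
assumption about the file's contents, not a quote from it):

    torch_outputs = torch_model(*example_inputs)

    assert len(torch_outputs) == len(onnxruntime_outputs)
    for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
        # Compare each PyTorch output against the matching ONNX Runtime output.
        torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))

    print("PyTorch and ONNX Runtime output matched!")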
do you want to add a customcard for this tutorial in index.rst (like this) so it's discoverable on the landing page under the ONNX selector

Maybe you can submit a follow up for this.

Nice suggestion!

I will do it as a follow up.