Commit 32a8626 (1 parent: 844afc6)

Cherry-pick: Add documentation for dynamo.compile backend (#2389) (#2416)

Signed-off-by: Dheeraj Peri <[email protected]>

File tree: 3 files changed (+100, -3 lines)

docsrc/index.rst (+2)

@@ -41,6 +41,7 @@ User Guide
 * :ref:`creating_a_ts_mod`
 * :ref:`getting_started_with_fx`
 * :ref:`torch_compile`
+* :ref:`dynamo_export`
 * :ref:`ptq`
 * :ref:`runtime`
 * :ref:`saving_models`
@@ -56,6 +57,7 @@ User Guide
 user_guide/creating_torchscript_module_in_python
 user_guide/getting_started_with_fx_path
 user_guide/torch_compile
+user_guide/dynamo_export
 user_guide/ptq
 user_guide/runtime
 user_guide/saving_models

docsrc/user_guide/dynamo_export.rst (new file, +82)

.. _dynamo_export:

Torch-TensorRT Dynamo Backend
=============================================

.. currentmodule:: torch_tensorrt.dynamo

.. automodule:: torch_tensorrt.dynamo
   :members:
   :undoc-members:
   :show-inheritance:

This guide presents the Torch-TensorRT dynamo backend, which optimizes PyTorch models
using TensorRT in an ahead-of-time fashion.

Using the Dynamo backend
----------------------------------------

PyTorch 2.1 introduced the ``torch.export`` APIs, which
can export graphs from PyTorch programs into ``ExportedProgram`` objects. The Torch-TensorRT dynamo
backend compiles these ``ExportedProgram`` objects and optimizes them using TensorRT. Here is a simple
usage of the dynamo backend:

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()]
    exp_program = torch.export.export(model, tuple(inputs))
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)  # Output is a torch.fx.GraphModule
    trt_gm(*inputs)

.. note:: ``torch_tensorrt.dynamo.compile`` is the main API for users to interact with the Torch-TensorRT dynamo backend. The input to this API should be an ``ExportedProgram`` (ideally the output of ``torch.export.export`` or ``torch_tensorrt.dynamo.trace``, discussed in the section below) and the output is a ``torch.fx.GraphModule`` object.
Customizable Settings
----------------------

There are many options for users to customize their settings for optimizing with TensorRT.
Some of the frequently used options are as follows:

* ``inputs`` - For static shapes, this can be a list of torch tensors or ``torch_tensorrt.Input`` objects. For dynamic shapes, this should be a list of ``torch_tensorrt.Input`` objects.
* ``enabled_precisions`` - Set of precisions that the TensorRT builder can use during optimization.
* ``truncate_long_and_double`` - Truncates long and double values to int and float respectively.
* ``torch_executed_ops`` - Operators which are forced to be executed by Torch.
* ``min_block_size`` - Minimum number of consecutive operators required to be executed as a TensorRT segment.

The complete list of options can be found `here <https://github.com/pytorch/TensorRT/blob/123a486d6644a5bbeeec33e2f32257349acc0b8f/py/torch_tensorrt/dynamo/compile.py#L51-L77>`_.

.. note:: We do not support INT precision currently in dynamo. Support for this currently exists in
   our Torchscript IR. We plan to implement similar support for dynamo in our next release.
Under the hood
--------------

Under the hood, ``torch_tensorrt.dynamo.compile`` performs the following on the graph:

* Lowering - Applies lowering passes to add/remove operators for optimal conversion.
* Partitioning - Partitions the graph into PyTorch and TensorRT segments based on the ``min_block_size`` and ``torch_executed_ops`` fields.
* Conversion - PyTorch ops get converted into TensorRT ops in this phase.
* Optimization - Post conversion, we build the TensorRT engine and embed it inside the PyTorch graph.
Tracing
-------

``torch_tensorrt.dynamo.trace`` can be used to trace a PyTorch graph and produce an ``ExportedProgram``.
This internally performs some decompositions of operators for downstream optimization.
The ``ExportedProgram`` can then be used with the ``torch_tensorrt.dynamo.compile`` API.
If you have dynamic input shapes in your model, you can use ``torch_tensorrt.dynamo.trace`` to export
the model with dynamic shapes. Alternatively, you can use ``torch.export`` `with constraints <https://pytorch.org/docs/stable/export.html#expressing-dynamism>`_ directly as well.

.. code-block:: python

    import torch
    import torch_tensorrt

    inputs = [torch_tensorrt.Input(min_shape=(1, 3, 224, 224),
                                   opt_shape=(4, 3, 224, 224),
                                   max_shape=(8, 3, 224, 224),
                                   dtype=torch.float32)]
    model = MyModel().eval()
    exp_program = torch_tensorrt.dynamo.trace(model, inputs)

docsrc/user_guide/saving_models.rst (+16, -3)

@@ -1,16 +1,25 @@
-.. _runtime:
+.. _saving_models:
 
 Saving models compiled with Torch-TensorRT
 ====================================
+.. currentmodule:: torch_tensorrt.dynamo
 
+.. automodule:: torch_tensorrt.dynamo
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation.
 
-1) Dynamo IR
+Dynamo IR
+-------------
 
 Starting with the 2.1 release of Torch-TensorRT, we are switching the default compilation to be dynamo based.
 The output of `ir=dynamo` compilation is a `torch.fx.GraphModule` object. There are two ways to save these objects:
 
 a) Converting to Torchscript
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
 `torch.fx.GraphModule` objects cannot be serialized directly. Hence we use `torch.jit.trace` to convert this into a `ScriptModule` object which can be saved to disk.
 The following code illustrates this approach.

@@ -30,6 +39,8 @@ The following code illustrates this approach.
 model(inputs)
 
 b) ExportedProgram
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
 `torch.export.ExportedProgram` is a new format introduced in PyTorch 2.1. After we compile a PyTorch module using Torch-TensorRT, the resultant
 `torch.fx.GraphModule` along with additional metadata can be used to create an `ExportedProgram` which can be saved and loaded from disk.

@@ -54,7 +65,9 @@ This is needed as `torch._export` serialization cannot handle serializing and de
 
 NOTE: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue: https://github.com/pytorch/TensorRT/issues/2341

-2) Torchscript IR
+Torchscript IR
+-------------
 
 In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT is using Torchscript IR.
 This behavior stays the same in 2.X versions as well.
