
Commit 4a920a1

docs: Add documentation of torch.compile backend usage (#2363)
1 parent 7a1cdb3 commit 4a920a1

File tree: 4 files changed, +118 −12 lines

docsrc/index.rst

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,7 @@ User Guide
 ------------
 * :ref:`creating_a_ts_mod`
 * :ref:`getting_started_with_fx`
+* :ref:`torch_compile`
 * :ref:`ptq`
 * :ref:`runtime`
 * :ref:`saving_models`

@@ -54,6 +55,7 @@ User Guide

 user_guide/creating_torchscript_module_in_python
 user_guide/getting_started_with_fx_path
+user_guide/torch_compile
 user_guide/ptq
 user_guide/runtime
 user_guide/saving_models

docsrc/user_guide/dynamic_shapes.rst

Lines changed: 1 addition & 11 deletions
@@ -1,4 +1,4 @@
-.. _runtime:
+.. _dynamic_shapes:

 Dynamic shapes with Torch-TensorRT
 ====================================
@@ -206,13 +206,3 @@ In the future, we plan to explore the option of compiling with dynamic shapes in
     # Recompilation happens with modified batch size
     inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32)
     trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs_bs2)
-
-
-
-
-
-
-
-
-
-

docsrc/user_guide/torch_compile.rst

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
.. _torch_compile:

Torch-TensorRT `torch.compile` Backend
======================================================

.. currentmodule:: torch_tensorrt.dynamo

.. automodule:: torch_tensorrt.dynamo
    :members:
    :undoc-members:
    :show-inheritance:

This guide presents the Torch-TensorRT `torch.compile` backend: a deep learning compiler which uses TensorRT to accelerate JIT-style workflows across a wide variety of models.

Key Features
--------------------------------------------

The primary goal of the Torch-TensorRT `torch.compile` backend is to enable Just-In-Time compilation workflows by combining the simplicity of the `torch.compile` API with the performance of TensorRT. Invoking the `torch.compile` backend is as simple as importing the `torch_tensorrt` package and specifying the backend:

.. code-block:: python

    import torch_tensorrt
    ...
    optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False)

.. note:: Many additional customization options are available to the user. These will be discussed in further depth in this guide.

The backend can handle a variety of challenging model structures and offers a simple-to-use interface for effective acceleration of models. Additionally, it has many customization options to ensure the compilation process is well-suited to the specific use case.

Customizable Settings
----------------------

.. autoclass:: CompilationSettings

Custom Setting Usage
^^^^^^^^^^^^^^^^^^^^

.. code-block:: python

    import torch_tensorrt
    ...
    optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False,
                                    options={"truncate_long_and_double": True,
                                             "precision": torch.half,
                                             "debug": True,
                                             "min_block_size": 2,
                                             "torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
                                             "optimization_level": 4,
                                             "use_python_runtime": False,})

.. note:: Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.

Compilation
-----------------

Compilation is triggered by passing inputs to the model, as follows:

.. code-block:: python

    import torch_tensorrt
    ...
    # Causes model compilation to occur
    first_outputs = optimized_model(*inputs)

    # Subsequent inference runs with the same or similar inputs will not cause recompilation
    # For a full discussion of this, see "Recompilation Conditions" below
    second_outputs = optimized_model(*inputs)

After Compilation
-----------------

The compiled model can be used for inference within the Python session, and will recompile according to the recompilation conditions detailed below. In addition to general inference, the compilation process can be a helpful tool in determining model performance, current operator coverage, and feasibility of serialization. Each of these points will be covered in detail below.

Model Performance
^^^^^^^^^^^^^^^^^

The optimized model returned from `torch.compile` is useful for model benchmarking since it can automatically handle changes in the compilation context, or differing inputs that could require recompilation. When benchmarking inputs of varying distributions, batch sizes, or other criteria, this can save time.
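
As a rough illustration of such a benchmarking workflow, the loop below is a minimal sketch: the warm-up call, iteration count, and use of `torch.cuda.synchronize` are illustrative assumptions rather than part of the backend API.

.. code-block:: python

    import time

    import torch
    import torch_tensorrt
    ...
    # Warm-up call triggers compilation, so it is excluded from the timed loop
    optimized_model(*inputs)
    torch.cuda.synchronize()

    # Subsequent runs reuse the compiled TensorRT engines
    start = time.perf_counter()
    for _ in range(100):
        optimized_model(*inputs)
    torch.cuda.synchronize()
    print(f"Average latency: {(time.perf_counter() - start) / 100 * 1000:.2f} ms")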

Operator Coverage
^^^^^^^^^^^^^^^^^

Compilation is also a useful tool in determining operator coverage for a particular model. For instance, the following compilation command will display the operator coverage for each graph, but will not compile the model - effectively providing a "dryrun" mechanism:

.. code-block:: python

    import torch_tensorrt
    ...
    optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False,
                                    options={"debug": True,
                                             "min_block_size": float("inf"),})

If key operators for your model are unsupported, see :ref:`dynamo_conversion` to contribute your own converters, or file an issue here: https://github.com/pytorch/TensorRT/issues.

Feasibility of Serialization
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Compilation can also be helpful in demonstrating graph breaks and the feasibility of serialization of a particular model. For instance, if a model has no graph breaks and compiles successfully with the Torch-TensorRT backend, then that model should be compilable and serializable via the `torch_tensorrt` Dynamo IR, as discussed in :ref:`dynamic_shapes`. To determine the number of graph breaks in a model, the `torch._dynamo.explain` function is very useful:

.. code-block:: python

    import torch
    import torch_tensorrt
    ...
    explanation = torch._dynamo.explain(model)(*inputs)
    print(f"Graph breaks: {explanation.graph_break_count}")
    optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False, options={"truncate_long_and_double": True})
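
Building on the above, a model which `torch._dynamo.explain` reports as having zero graph breaks can generally also be compiled ahead of time through the Dynamo IR. The snippet below is a sketch of that follow-on step; the example input shape is an illustrative assumption:

.. code-block:: python

    import torch
    import torch_tensorrt
    ...
    # Ahead-of-time compilation via the Dynamo IR, which produces a serializable module
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=[torch.randn((1, 3, 224, 224)).cuda()])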

Dynamic Shape Support
---------------------

The Torch-TensorRT `torch.compile` backend currently requires recompilation for each new batch size encountered, and it is preferred to use the `dynamic=False` argument when compiling with this backend. Full dynamic shape support is planned for a future release.

Recompilation Conditions
------------------------

Once the model has been compiled, subsequent inference inputs with the same shape and data type, which traverse the graph in the same way, will not require recompilation. Furthermore, each new recompilation will be cached for the duration of the Python session. For instance, if inputs of batch size 4 and 8 are provided to the model, causing two recompilations, no further recompilation would be necessary for future inputs with those batch sizes during inference within the same session. Support for engine cache serialization is planned for a future release.

Recompilation is generally triggered by one of two events: encountering inputs of different sizes, or inputs which traverse the model code differently. The latter scenario can occur when the model code includes conditional logic, complex loops, or data-dependent shapes. `torch.compile` handles guarding in both of these scenarios and determines when recompilation is necessary.
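
As a concrete sketch of the caching behavior described above (the input shapes and use of GPU tensors are illustrative assumptions):

.. code-block:: python

    import torch
    import torch_tensorrt
    ...
    inputs_bs4 = torch.randn((4, 3, 224, 224), dtype=torch.float32).cuda()
    inputs_bs8 = torch.randn((8, 3, 224, 224), dtype=torch.float32).cuda()

    optimized_model(inputs_bs4)  # Compiles for batch size 4
    optimized_model(inputs_bs8)  # Recompiles for batch size 8
    optimized_model(inputs_bs4)  # Cached batch-size-4 compilation is reused; no recompilation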

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 5 additions & 1 deletion
@@ -39,9 +39,13 @@ class CompilationSettings:
         use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
             based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
             argument as None
-        truncate_long_and_double (bool): Truncate int64/float64 TRT engine inputs or weights to int32/float32
+        truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
+        use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
         enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
             or only a selected subset of them
+        device (Device): GPU to compile the model on
+        require_full_compilation (bool): Whether to require that the graph is fully compiled in TensorRT.
+            Only applicable for `ir="dynamo"`; has no effect for the `torch.compile` path
     """

     precision: torch.dtype = PRECISION
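
For context, the fields documented above correspond to the keyword options accepted by the Torch-TensorRT `torch.compile` backend. As a minimal sketch, assuming the dataclass can also be constructed directly with a subset of its fields (in typical use these values are instead passed through the `options` dict of `torch.compile`):

    from torch_tensorrt.dynamo import CompilationSettings

    # Hypothetical direct construction; field names are taken from the docstring above,
    # and remaining fields keep their defaults
    settings = CompilationSettings(
        truncate_long_and_double=True,
        use_fast_partitioner=True,
        enable_experimental_decompositions=False,
    )
    print(settings)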
