
Commit 18a462f

chore: Set return type of compilation to ExportedProgram [release/2.2] (#2607)
1 parent 5384ba8 commit 18a462f

28 files changed: +426 / -296 lines changed

docsrc/user_guide/saving_models.rst

+31-18
@@ -14,14 +14,18 @@ Saving models compiled with Torch-TensorRT varies slightly with the `ir` that ha
 Dynamo IR
 -------------
 
-Starting with 2.1 release of Torch-TensorRT, we are switching the default compilation to be dynamo based.
-The output of `ir=dynamo` compilation is a `torch.fx.GraphModule` object. There are two ways to save these objects
+The output type of `ir=dynamo` compilation of Torch-TensorRT is a `torch.export.ExportedProgram` object by default.
+In addition, we provide a new parameter `output_format` in the `CompilationSettings` object provided before compilation.
+The `output_format` can take the following options:
 
-a) Converting to Torchscript
+* `exported_program` (or) `ep` : This is the default. It returns an ExportedProgram.
+* `torchscript` (or) `ts` : This returns a TorchScript module.
+* `graph_module` (or) `fx` : This returns a torch.fx.GraphModule which can be traced into TorchScript to save to disk.
+
+a) Torchscript
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-`torch.fx.GraphModule` objects cannot be serialized directly. Hence we use `torch.jit.trace` to convert this into a `ScriptModule` object which can be saved to disk.
-The following code illustrates this approach.
+If you set `output_format="torchscript"`, this will return a `ScriptModule` which can be serialized via `torch.jit.save`.
 
 .. code-block:: python
@@ -30,9 +34,9 @@ The following code illustrates this approach.
 
     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)  # Output is a torch.fx.GraphModule
-    trt_traced_model = torch.jit.trace(trt_gm, inputs)
-    torch.jit.save(trt_traced_model, "trt_model.ts")
+    # trt_ts is a torch.jit.ScriptModule object
+    trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="torchscript")
+    torch.jit.save(trt_ts, "trt_model.ts")
 
     # Later, you can load it and run inference
     model = torch.jit.load("trt_model.ts").cuda()
@@ -41,8 +45,7 @@ The following code illustrates this approach.
 b) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-`torch.export.ExportedProgram` is a new format introduced in Pytorch 2.1. After we compile a Pytorch module using Torch-TensorRT, the resultant
-`torch.fx.GraphModule` along with additional metadata can be used to create `ExportedProgram` which can be saved and loaded from disk.
+`torch.export.ExportedProgram`, a new format introduced in PyTorch 2.X, is the default return type of Torch-TensorRT compilation.
 
 .. code-block:: python
@@ -51,26 +54,36 @@ b) ExportedProgram
 
     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)  # Output is a torch.fx.GraphModule
-    # Transform and create an exported program
-    trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
-    torch.export.save(trt_exp_program, "trt_model.ep")
+    # trt_ep is a torch.export.ExportedProgram object
+    trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+    torch.export.save(trt_ep, "trt_model.ep")
 
     # Later, you can load it and run inference
     model = torch.export.load("trt_model.ep")
     model(*inputs)
 
-`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stitches all the nodes together.
-This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
+c) GraphModule
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-.. note:: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341
+We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
+Internally, the partitioning, lowering, and conversion phases operate on GraphModule objects. These can be either traced into TorchScript modules or
+exported into `ExportedProgram` objects.
 
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    model = MyModel().eval().cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    # trt_gm is a torch.fx.GraphModule object
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="graph_module")
 
 Torchscript IR
 -------------
 
 In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT is using Torchscript IR.
-This behavior stays the same in 2.X versions as well.
+For `ir=ts`, this behavior stays the same in 2.X versions as well.
 
 .. code-block:: python

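The new GraphModule example above stops after compilation without showing a save step. As a rough sketch of the two save paths the updated text mentions (tracing into TorchScript, or exporting into an ExportedProgram), assuming the compiled GraphModule can be traced and exported like a regular nn.Module:

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # MyModel as in the docs above
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
    # Compile and ask for a raw torch.fx.GraphModule
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="graph_module")

    # Path 1: trace the GraphModule into TorchScript and save it
    trt_ts = torch.jit.trace(trt_gm, inputs)
    torch.jit.save(trt_ts, "trt_model.ts")

    # Path 2 (assumption): re-export the GraphModule and save the ExportedProgram
    trt_ep = torch.export.export(trt_gm, tuple(inputs))
    torch.export.save(trt_ep, "trt_model.ep")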
examples/int8/training/vgg16/vgg16.py

+3-1
@@ -3,10 +3,12 @@
 - [Very Deep Convolutional Networks for Large-Scale Image Recognition](
 https://arxiv.org/abs/1409.1556) (ICLR 2015)
 """
+
+from functools import reduce
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from functools import reduce
 
 
 class VGG(nn.Module):

py/torch_tensorrt/_Device.py

+6-4
@@ -32,12 +32,14 @@ class Device(object):
         allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
     """
 
-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
     gpu_id: int = -1  #: Device ID for target GPU
     dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )
 
     def __init__(self, *args: Any, **kwargs: Any):
         """__init__ Method for torch_tensorrt.Device

py/torch_tensorrt/_Input.py

+6-6
@@ -28,12 +28,12 @@ class _ShapeMode(Enum):
         STATIC = 0
         DYNAMIC = 1
 
-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
     dtype: _enums.dtype = (
         _enums.dtype.unknown
     )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)

py/torch_tensorrt/dynamo/_compiler.py

+11-4
@@ -27,6 +27,7 @@
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
     OPTIMIZATION_LEVEL,
+    OUTPUT_FORMAT,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     REFIT,
@@ -38,6 +39,7 @@
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
 )
+from torch_tensorrt.dynamo._exporter import export
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     convert_module,
@@ -88,6 +90,7 @@ def compile(
     use_python_runtime: bool = USE_PYTHON_RUNTIME,
     use_fast_partitioner: bool = USE_FAST_PARTITIONER,
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    output_format: str = OUTPUT_FORMAT,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -144,6 +147,7 @@ def compile(
         use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization
         use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
         enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
+        output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -200,9 +204,9 @@ def compile(
         "device": device,
         "workspace_size": workspace_size,
         "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else set(),
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else set()
+        ),
         "pass_through_build_failures": pass_through_build_failures,
         "max_aux_streams": max_aux_streams,
         "version_compatible": version_compatible,
@@ -219,11 +223,14 @@ def compile(
         "dla_sram_size": dla_sram_size,
         "dla_local_dram_size": dla_local_dram_size,
         "dla_global_dram_size": dla_global_dram_size,
+        "output_format": output_format,
     }
 
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
-    return compile_module(gm, inputs, settings)
+    trt_gm = compile_module(gm, inputs, settings)
+    trt_result = export(trt_gm, torch_inputs, output_format)
+    return trt_result
 
 
 def compile_module(

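For context on the new call above: `compile()` now hands the compiled GraphModule to `export(trt_gm, torch_inputs, output_format)` from `torch_tensorrt.dynamo._exporter`, which selects the return type based on `output_format`. Below is a minimal sketch of that kind of dispatch; the helper name and bodies are illustrative assumptions, not the library's actual `_exporter.export` implementation.

.. code-block:: python

    # Illustrative sketch only; the real torch_tensorrt.dynamo._exporter.export
    # may differ in behavior and supported options.
    from typing import Any, Sequence

    import torch


    def export_sketch(
        gm: torch.fx.GraphModule,
        inputs: Sequence[torch.Tensor],
        output_format: str,
    ) -> Any:
        if output_format in ("exported_program", "ep"):
            # Wrap the compiled module in a torch.export.ExportedProgram
            return torch.export.export(gm, tuple(inputs))
        elif output_format in ("torchscript", "ts"):
            # Trace the compiled module into a TorchScript ScriptModule
            return torch.jit.trace(gm, tuple(inputs))
        elif output_format in ("graph_module", "fx"):
            # Return the torch.fx.GraphModule unchanged
            return gm
        raise ValueError(
            f"Unsupported output_format: {output_format!r}; expected "
            '"exported_program"/"ep", "torchscript"/"ts", or "graph_module"/"fx".'
        )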
py/torch_tensorrt/dynamo/_defaults.py

+1
@@ -24,6 +24,7 @@
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
 REFIT = False
 REQUIRE_FULL_COMPILATION = False
+OUTPUT_FORMAT = "exported_program"
 
 
 def default_device() -> Device:

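With `OUTPUT_FORMAT = "exported_program"` as the new default, a quick usage sketch of the default return type and the short aliases listed in the `compile()` docstring above; `MyModel` and the alias spellings are taken on the docstring's word rather than verified against the release.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # assumed user model, as in the docs
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    # Default output_format is "exported_program"
    trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
    assert isinstance(trt_ep, torch.export.ExportedProgram)

    # The short aliases from the docstring should also work
    trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="ts")
    trt_fx = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="fx")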