You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docsrc/user_guide/saving_models.rst
+31-18
Original file line number
Diff line number
Diff line change
@@ -14,14 +14,18 @@ Saving models compiled with Torch-TensorRT varies slightly with the `ir` that ha
14
14
Dynamo IR
15
15
-------------
16
16
17
-
Starting with 2.1 release of Torch-TensorRT, we are switching the default compilation to be dynamo based.
18
-
The output of `ir=dynamo` compilation is a `torch.fx.GraphModule` object. There are two ways to save these objects
17
+
The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.export.ExportedProgram` object by default.
18
+
In addition, we provide a new parameter `output_format` in the `CompilationSetting` object provided before compilation.
19
+
The `output_format` can take the following options
19
20
20
-
a) Converting to Torchscript
21
+
* `exported_program` (or) `ep` : This is the default. Returns an ExportedProgram
22
+
* `torchscript` (or) `ts` : This returns a TorchScript module
23
+
* `graph_module` (or) `fx` : This returns a torch.fx.GraphModule which can be traced into Torchscript to save to disk.
24
+
25
+
a) Torchscript
21
26
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
22
27
23
-
`torch.fx.GraphModule` objects cannot be serialized directly. Hence we use `torch.jit.trace` to convert this into a `ScriptModule` object which can be saved to disk.
24
-
The following code illustrates this approach.
28
+
If you set the `output_format="torchscript"`, this will return a `ScriptModule` which can be serialized via torch.jit.save
25
29
26
30
.. code-block:: python
27
31
@@ -30,9 +34,9 @@ The following code illustrates this approach.
30
34
31
35
model = MyModel().eval().cuda()
32
36
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
33
-
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stiches all the nodes together.
64
-
This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
65
+
c) GraphModule
66
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
65
67
66
-
.. note:: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341
68
+
We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
69
+
Internally, partitioning, lowering, conversion phases operate using GraphModule objects. These can be either traced into a Torchscript modules or
Copy file name to clipboardExpand all lines: py/torch_tensorrt/_Input.py
+6-6
Original file line number
Diff line number
Diff line change
@@ -28,12 +28,12 @@ class _ShapeMode(Enum):
28
28
STATIC=0
29
29
DYNAMIC=1
30
30
31
-
shape_mode: Optional[
32
-
_ShapeMode
33
-
] =None#: Is input statically or dynamically shaped
34
-
shape: Optional[
35
-
Tuple[int, ...] |Dict[str, Tuple[int, ...]]
36
-
] =None#: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
None#: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
36
+
)
37
37
dtype: _enums.dtype= (
38
38
_enums.dtype.unknown
39
39
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -144,6 +147,7 @@ def compile(
144
147
use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization
145
148
use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optiminal. Use the global paritioner (``False``) if looking for best performance
146
149
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the grap easier to covert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
150
+
output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
147
151
**kwargs: Any,
148
152
Returns:
149
153
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
0 commit comments