
Commit 6d2a26b

fix: Unify export/compile compilation utilities
- Update argument naming for compatibility with `fx_ts_compat` utilities
- Add support for new TRT 8.6 utilities, including auxiliary streams, version compatibility, and optimization levels (see the usage sketch below)
- Add support for TRTModuleNext use during compilation with Dynamo compile
- Improve documentation of features and version checking for TRT feature compatibility
- Add test cases for new `TRTModuleNext` functionality and for TRT custom options functionality
1 parent: 0f712bd
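
For context, a minimal usage sketch of the newly exposed options, assuming the `compile` entrypoint in `py/torch_tensorrt/dynamo/backend/__init__.py` (modified below) takes the usual `(module, inputs)` leading arguments; `MyModel` stands in for a hypothetical user model:

import torch
from torch_tensorrt.dynamo.backend import compile as dynamo_compile

model = MyModel().eval().cuda()                         # hypothetical user model
inputs = [torch.randn(16, 3, 224, 224, device="cuda")]

# New TRT 8.6-era options surfaced by this commit (names match the diffs below)
trt_model = dynamo_compile(
    model,
    inputs,
    min_block_size=2,
    max_aux_streams=1,         # cap auxiliary TRT streams per engine
    version_compatible=True,   # build version forward-compatible engine plans
    optimization_level=4,      # 0-5; higher means longer builds, more tactics explored
    use_experimental_rt=True,  # wrap engines in TRTModuleNext rather than the fx TRTModule
)
out = trt_model(*inputs)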

9 files changed: +255 −18 lines

examples/dynamo/dynamo_compile_advanced_usage.py (+2)

@@ -66,6 +66,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
     debug=True,
     min_block_size=2,
     torch_executed_ops={},
+    optimization_level=4,
+    use_experimental_rt=True,
 )
 
 # Run the model on an input to cause compilation, as so:

py/torch_tensorrt/dynamo/backend/__init__.py (+33 −6)

@@ -4,7 +4,7 @@
 import torch_tensorrt
 from functools import partial
 
-from typing import Any, Sequence
+from typing import Any, Optional, Sequence
 from torch_tensorrt import EngineCapability, Device
 from torch_tensorrt.fx.utils import LowerPrecision
 
@@ -14,9 +14,13 @@
 from torch_tensorrt.dynamo.backend._defaults import (
     PRECISION,
     DEBUG,
-    MAX_WORKSPACE_SIZE,
+    WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    MAX_AUX_STREAMS,
+    VERSION_COMPATIBLE,
+    OPTIMIZATION_LEVEL,
+    USE_EXPERIMENTAL_RT,
 )
 
 
@@ -35,7 +39,7 @@ def compile(
     debug=DEBUG,
     capability=EngineCapability.default,
     num_avg_timing_iters=1,
-    workspace_size=MAX_WORKSPACE_SIZE,
+    workspace_size=WORKSPACE_SIZE,
     dla_sram_size=1048576,
     dla_local_dram_size=1073741824,
     dla_global_dram_size=536870912,
@@ -45,6 +49,10 @@ def compile(
     min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],
     torch_executed_modules=[],
+    max_aux_streams=MAX_AUX_STREAMS,
+    version_compatible=VERSION_COMPATIBLE,
+    optimization_level=OPTIMIZATION_LEVEL,
+    use_experimental_rt=USE_EXPERIMENTAL_RT,
     **kwargs,
 ):
     if debug:
@@ -86,6 +94,10 @@ def compile(
         workspace_size=workspace_size,
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
+        max_aux_streams=max_aux_streams,
+        version_compatible=version_compatible,
+        optimization_level=optimization_level,
+        use_experimental_rt=use_experimental_rt,
         **kwargs,
     )
 
@@ -105,19 +117,30 @@ def compile(
 def create_backend(
     precision: LowerPrecision = PRECISION,
     debug: bool = DEBUG,
-    workspace_size: int = MAX_WORKSPACE_SIZE,
+    workspace_size: int = WORKSPACE_SIZE,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Sequence[str] = set(),
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
+    version_compatible: bool = VERSION_COMPATIBLE,
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
+    use_experimental_rt: bool = USE_EXPERIMENTAL_RT,
     **kwargs,
 ):
     """Create torch.compile backend given specified arguments
 
     Args:
         precision:
         debug: Whether to print out verbose debugging information
-        workspace_size: Maximum workspace TRT is allowed to use for the module
-        precision: Model Layer precision
+        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
+        min_block_size: Minimum number of operators per TRT-Engine Block
+        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible: Provide version forward-compatibility for engine plan files
+        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines
     Returns:
         Backend for torch.compile
     """
@@ -131,6 +154,10 @@ def create_backend(
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
         pass_through_build_failures=pass_through_build_failures,
+        max_aux_streams=max_aux_streams,
+        version_compatible=version_compatible,
+        optimization_level=optimization_level,
+        use_experimental_rt=use_experimental_rt,
     )
 
     return partial(
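
A sketch of using the extended `create_backend` directly with `torch.compile`; the option names come from the diff above, while the surrounding model and input shapes are hypothetical:

import torch
from torch_tensorrt.dynamo.backend import create_backend

# Backend callable for torch.compile carrying the new TRT 8.6 options
backend = create_backend(
    debug=False,
    min_block_size=2,
    max_aux_streams=2,          # allow at most two auxiliary TRT streams per engine
    version_compatible=True,    # engine plans stay loadable by newer TRT versions
    optimization_level=5,       # broadest tactic search, longest build (TRT default is 3)
    use_experimental_rt=False,  # keep the fx TRTModule runtime
)

model = MyModel().eval().cuda()  # hypothetical user model
optimized = torch.compile(model, backend=backend)
optimized(torch.randn(1, 3, 224, 224, device="cuda"))  # first call triggers compilation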

py/torch_tensorrt/dynamo/backend/_defaults.py (+5 −1)

@@ -3,6 +3,10 @@
 
 PRECISION = LowerPrecision.FP32
 DEBUG = False
-MAX_WORKSPACE_SIZE = 20 << 30
+WORKSPACE_SIZE = 0
 MIN_BLOCK_SIZE = 5
 PASS_THROUGH_BUILD_FAILURES = False
+MAX_AUX_STREAMS = None
+VERSION_COMPATIBLE = False
+OPTIMIZATION_LEVEL = None
+USE_EXPERIMENTAL_RT = False
py/torch_tensorrt/dynamo/backend/_settings.py (+11 −3)

@@ -1,21 +1,29 @@
 from dataclasses import dataclass, field
-from typing import Sequence
+from typing import Optional, Sequence
 
 from torch_tensorrt.fx.utils import LowerPrecision
 from torch_tensorrt.dynamo.backend._defaults import (
     PRECISION,
     DEBUG,
-    MAX_WORKSPACE_SIZE,
+    WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    MAX_AUX_STREAMS,
+    VERSION_COMPATIBLE,
+    OPTIMIZATION_LEVEL,
+    USE_EXPERIMENTAL_RT,
 )
 
 
 @dataclass(frozen=True)
 class CompilationSettings:
     precision: LowerPrecision = PRECISION
     debug: bool = DEBUG
-    workspace_size: int = MAX_WORKSPACE_SIZE
+    workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Sequence[str] = field(default_factory=set)
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS
+    version_compatible: bool = VERSION_COMPATIBLE
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL
+    use_experimental_rt: bool = USE_EXPERIMENTAL_RT
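
The frozen `CompilationSettings` dataclass above now carries the new fields; a minimal construction sketch, using the import path shown in the conversion diff below:

from torch_tensorrt.dynamo.backend._settings import CompilationSettings

# None for max_aux_streams / optimization_level defers to TRT's own defaults
settings = CompilationSettings(
    debug=True,
    workspace_size=0,  # 0 now means "use TRT's default workspace sizing"
    max_aux_streams=None,
    version_compatible=False,
    optimization_level=None,
    use_experimental_rt=True,
)
# frozen=True makes instances immutable; reassigning a field raises FrozenInstanceError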

py/torch_tensorrt/dynamo/backend/backends.py (+1)

@@ -135,6 +135,7 @@ def _compile_module(
         submodule,
         submodule_inputs,
         settings=settings,
+        name=name,
     )
 
     # Replace FX Module with TRT Module
py/torch_tensorrt/dynamo/backend/conversion.py

@@ -1,9 +1,10 @@
 from typing import Sequence, Union
 import torch
+import io
 from torch_tensorrt.fx.trt_module import TRTModule
 from torch_tensorrt import TRTModuleNext
 from torch_tensorrt.dynamo.backend._settings import CompilationSettings
-from torch_tensorrt.fx.fx2trt import (
+from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
     InputTensorSpec,
     TRTInterpreter,
 )
@@ -15,30 +16,50 @@ def convert_module(
     module: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
+    name: str = "",
 ) -> Union[TRTModuleNext, TRTModule]:
     """Convert an FX module to a TRT module
     Args:
         module: FX GraphModule to convert
         inputs: Sequence of Tensors representing inputs to the module
         settings: Compilation settings
+        name: TRT engine name
     Returns:
         TRTModule or TRTModuleNext
     """
-    interp = TRTInterpreter(
+    interpreter = TRTInterpreter(
         module,
         InputTensorSpec.from_tensors(inputs),
         explicit_batch_dimension=True,
         logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
     )
 
-    r = interp.run(
-        max_workspace_size=settings.workspace_size,
+    interpreter_result = interpreter.run(
+        workspace_size=settings.workspace_size,
         lower_precision=settings.precision,
         profiling_verbosity=(
             trt.ProfilingVerbosity.VERBOSE
             if settings.debug
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
         ),
+        max_aux_streams=settings.max_aux_streams,
+        version_compatible=settings.version_compatible,
+        optimization_level=settings.optimization_level,
     )
 
-    return TRTModule(*r)
+    if settings.use_experimental_rt:
+        with io.BytesIO() as engine_bytes:
+            engine_bytes.write(interpreter_result.engine.serialize())
+            engine_str = engine_bytes.getvalue()
+        return TRTModuleNext(
+            serialized_engine=engine_str,
+            name=name,
+            input_binding_names=interpreter_result.input_names,
+            output_binding_names=interpreter_result.output_names,
+        )
+    else:
+        return TRTModule(
+            engine=interpreter_result.engine,
+            input_names=interpreter_result.input_names,
+            output_names=interpreter_result.output_names,
+        )
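
Echoing the test cases mentioned in the commit message, a hedged sketch of a parity check between the two runtime paths; it assumes `compile` accepts `(module, inputs)` positionally and that the model returns a single tensor:

import torch
from torch_tensorrt.dynamo.backend import compile as dynamo_compile

def check_runtime_parity(model, inputs, rtol=1e-3, atol=1e-3):
    """Compare the legacy fx TRTModule path against the TRTModuleNext path."""
    legacy = dynamo_compile(model, inputs, use_experimental_rt=False)
    experimental = dynamo_compile(model, inputs, use_experimental_rt=True)
    ref = legacy(*inputs)
    out = experimental(*inputs)
    assert torch.allclose(ref, out, rtol=rtol, atol=atol), "Runtime outputs diverge"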
