Skip to content

Commit 6338bd5

Browse files
authored
fix/feat: Add support for multiple TRT Build Args (#2510)
1 parent 1ff10a6 commit 6338bd5

File tree

7 files changed

+206
-56
lines changed

7 files changed

+206
-56
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,21 @@
1616
from torch_tensorrt.dynamo._defaults import (
1717
DEBUG,
1818
DEVICE,
19+
DISABLE_TF32,
20+
DLA_GLOBAL_DRAM_SIZE,
21+
DLA_LOCAL_DRAM_SIZE,
22+
DLA_SRAM_SIZE,
1923
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
24+
ENGINE_CAPABILITY,
2025
MAX_AUX_STREAMS,
2126
MIN_BLOCK_SIZE,
27+
NUM_AVG_TIMING_ITERS,
2228
OPTIMIZATION_LEVEL,
2329
PASS_THROUGH_BUILD_FAILURES,
2430
PRECISION,
31+
REFIT,
2532
REQUIRE_FULL_COMPILATION,
33+
SPARSE_WEIGHTS,
2634
TRUNCATE_LONG_AND_DOUBLE,
2735
USE_FAST_PARTITIONER,
2836
USE_PYTHON_RUNTIME,
@@ -51,17 +59,18 @@ def compile(
5159
inputs: Tuple[Any, ...],
5260
*,
5361
device: Optional[Union[Device, torch.device, str]] = DEVICE,
54-
disable_tf32: bool = False,
55-
sparse_weights: bool = False,
62+
disable_tf32: bool = DISABLE_TF32,
63+
sparse_weights: bool = SPARSE_WEIGHTS,
5664
enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
57-
refit: bool = False,
65+
engine_capability: EngineCapability = ENGINE_CAPABILITY,
66+
refit: bool = REFIT,
5867
debug: bool = DEBUG,
5968
capability: EngineCapability = EngineCapability.default,
60-
num_avg_timing_iters: int = 1,
69+
num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
6170
workspace_size: int = WORKSPACE_SIZE,
62-
dla_sram_size: int = 1048576,
63-
dla_local_dram_size: int = 1073741824,
64-
dla_global_dram_size: int = 536870912,
71+
dla_sram_size: int = DLA_SRAM_SIZE,
72+
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
73+
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
6574
calibrator: object = None,
6675
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
6776
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
@@ -199,6 +208,13 @@ def compile(
199208
"use_fast_partitioner": use_fast_partitioner,
200209
"enable_experimental_decompositions": enable_experimental_decompositions,
201210
"require_full_compilation": require_full_compilation,
211+
"disable_tf32": disable_tf32,
212+
"sparse_weights": sparse_weights,
213+
"refit": refit,
214+
"engine_capability": engine_capability,
215+
"dla_sram_size": dla_sram_size,
216+
"dla_local_dram_size": dla_local_dram_size,
217+
"dla_global_dram_size": dla_global_dram_size,
202218
}
203219

204220
settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11
import torch
2+
from tensorrt import EngineCapability
23
from torch_tensorrt._Device import Device
34

45
PRECISION = torch.float32
56
DEBUG = False
67
DEVICE = None
8+
DISABLE_TF32 = False
9+
DLA_LOCAL_DRAM_SIZE = 1073741824
10+
DLA_GLOBAL_DRAM_SIZE = 536870912
11+
DLA_SRAM_SIZE = 1048576
12+
ENGINE_CAPABILITY = EngineCapability.STANDARD
713
WORKSPACE_SIZE = 0
814
MIN_BLOCK_SIZE = 5
915
PASS_THROUGH_BUILD_FAILURES = False
1016
MAX_AUX_STREAMS = None
17+
NUM_AVG_TIMING_ITERS = 1
1118
VERSION_COMPATIBLE = False
1219
OPTIMIZATION_LEVEL = None
20+
SPARSE_WEIGHTS = False
1321
TRUNCATE_LONG_AND_DOUBLE = False
1422
USE_PYTHON_RUNTIME = False
1523
USE_FAST_PARTITIONER = True
1624
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
25+
REFIT = False
1726
REQUIRE_FULL_COMPILATION = False
1827

1928

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,25 @@
22
from typing import Optional, Set
33

44
import torch
5+
from tensorrt import EngineCapability
56
from torch_tensorrt._Device import Device
67
from torch_tensorrt.dynamo._defaults import (
78
DEBUG,
9+
DISABLE_TF32,
10+
DLA_GLOBAL_DRAM_SIZE,
11+
DLA_LOCAL_DRAM_SIZE,
12+
DLA_SRAM_SIZE,
813
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
14+
ENGINE_CAPABILITY,
915
MAX_AUX_STREAMS,
1016
MIN_BLOCK_SIZE,
17+
NUM_AVG_TIMING_ITERS,
1118
OPTIMIZATION_LEVEL,
1219
PASS_THROUGH_BUILD_FAILURES,
1320
PRECISION,
21+
REFIT,
1422
REQUIRE_FULL_COMPILATION,
23+
SPARSE_WEIGHTS,
1524
TRUNCATE_LONG_AND_DOUBLE,
1625
USE_FAST_PARTITIONER,
1726
USE_PYTHON_RUNTIME,
@@ -46,6 +55,14 @@ class CompilationSettings:
4655
device (Device): GPU to compile the model on
4756
require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
4857
Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
58+
disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
59+
sparse_weights (bool): Whether to allow the builder to use sparse weights
60+
refit (bool): Whether to build a refittable engine
61+
engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
62+
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
63+
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
64+
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
65+
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
4966
"""
5067

5168
precision: torch.dtype = PRECISION
@@ -63,3 +80,11 @@ class CompilationSettings:
6380
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
6481
device: Device = field(default_factory=default_device)
6582
require_full_compilation: bool = REQUIRE_FULL_COMPILATION
83+
disable_tf32: bool = DISABLE_TF32
84+
sparse_weights: bool = SPARSE_WEIGHTS
85+
refit: bool = REFIT
86+
engine_capability: EngineCapability = ENGINE_CAPABILITY
87+
num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS
88+
dla_sram_size: int = DLA_SRAM_SIZE
89+
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE
90+
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
55

66
import numpy as np
7+
import tensorrt as trt
78
import torch
89
import torch.fx
910
from torch.fx.node import _get_qualified_name
@@ -23,8 +24,6 @@
2324
from torch_tensorrt.fx.observer import Observer
2425
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
2526

26-
# @manual=//deeplearning/trt/python:py_tensorrt
27-
import tensorrt as trt
2827
from packaging import version
2928

3029
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -96,6 +95,7 @@ def __init__(
9695
self._itensor_to_tensor_meta: Dict[
9796
trt.tensorrt.ITensor, TensorMetadata
9897
] = dict()
98+
self.compilation_settings = compilation_settings
9999

100100
# Data types for TRT Module output Tensors
101101
self.output_dtypes = output_dtypes
@@ -118,40 +118,25 @@ def validate_conversion(self) -> Set[str]:
118118

119119
def run(
120120
self,
121-
workspace_size: int = 0,
122-
precision: torch.dtype = torch.float32, # TODO: @peri044 Needs to be expanded to set
123-
sparse_weights: bool = False,
124-
disable_tf32: bool = False,
125121
force_fp32_output: bool = False,
126122
strict_type_constraints: bool = False,
127123
algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
128124
timing_cache: Optional[trt.ITimingCache] = None,
129-
profiling_verbosity: Optional[trt.ProfilingVerbosity] = None,
130125
tactic_sources: Optional[int] = None,
131-
max_aux_streams: Optional[int] = None,
132-
version_compatible: bool = False,
133-
optimization_level: Optional[int] = None,
134126
) -> TRTInterpreterResult:
135127
"""
136128
Build TensorRT engine with some configs.
137129
Args:
138-
workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation.
139-
precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
140-
sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
141130
force_fp32_output: force output to be fp32
142131
strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
143132
algorithm_selector: set up algorithm selection for certain layer
144133
timing_cache: enable timing cache for TensorRT
145-
profiling_verbosity: TensorRT logging level
146-
max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
147-
version_compatible: Provide version forward-compatibility for engine plan files
148-
optimization_level: Builder optimization 0-5, higher levels imply longer build time,
149-
searching for more optimization options. TRT defaults to 3
150134
Return:
151135
TRTInterpreterResult
152136
"""
153137
TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
154138

139+
precision = self.compilation_settings.precision
155140
# For float outputs, we set their dtype to fp16 only if precision == torch.float16 and
156141
# force_fp32_output=False. Overridden by specifying output_dtypes
157142
self.output_fp16 = not force_fp32_output and precision == torch.float16
@@ -172,9 +157,9 @@ def run(
172157

173158
builder_config = self.builder.create_builder_config()
174159

175-
if workspace_size != 0:
160+
if self.compilation_settings.workspace_size != 0:
176161
builder_config.set_memory_pool_limit(
177-
trt.MemoryPoolType.WORKSPACE, workspace_size
162+
trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size
178163
)
179164

180165
cache = None
@@ -187,34 +172,66 @@ def run(
187172

188173
if version.parse(trt.__version__) >= version.parse("8.2"):
189174
builder_config.profiling_verbosity = (
190-
profiling_verbosity
191-
if profiling_verbosity
175+
trt.ProfilingVerbosity.VERBOSE
176+
if self.compilation_settings.debug
192177
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
193178
)
194179

195180
if version.parse(trt.__version__) >= version.parse("8.6"):
196-
if max_aux_streams is not None:
197-
_LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
198-
builder_config.max_aux_streams = max_aux_streams
199-
if version_compatible:
181+
if self.compilation_settings.max_aux_streams is not None:
182+
_LOGGER.info(
183+
f"Setting max aux streams to {self.compilation_settings.max_aux_streams}"
184+
)
185+
builder_config.max_aux_streams = (
186+
self.compilation_settings.max_aux_streams
187+
)
188+
if self.compilation_settings.version_compatible:
200189
_LOGGER.info("Using version compatible")
201190
builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
202-
if optimization_level is not None:
203-
_LOGGER.info(f"Using optimization level {optimization_level}")
204-
builder_config.builder_optimization_level = optimization_level
191+
if self.compilation_settings.optimization_level is not None:
192+
_LOGGER.info(
193+
f"Using optimization level {self.compilation_settings.optimization_level}"
194+
)
195+
builder_config.builder_optimization_level = (
196+
self.compilation_settings.optimization_level
197+
)
198+
199+
builder_config.engine_capability = self.compilation_settings.engine_capability
200+
builder_config.avg_timing_iterations = (
201+
self.compilation_settings.num_avg_timing_iters
202+
)
203+
204+
if self.compilation_settings.device.device_type == trt.DeviceType.DLA:
205+
builder_config.DLA_core = self.compilation_settings.device.dla_core
206+
_LOGGER.info(f"Using DLA core {self.compilation_settings.device.dla_core}")
207+
builder_config.set_memory_pool_limit(
208+
trt.MemoryPoolType.DLA_MANAGED_SRAM,
209+
self.compilation_settings.dla_sram_size,
210+
)
211+
builder_config.set_memory_pool_limit(
212+
trt.MemoryPoolType.DLA_LOCAL_DRAM,
213+
self.compilation_settings.dla_local_dram_size,
214+
)
215+
builder_config.set_memory_pool_limit(
216+
trt.MemoryPoolType.DLA_GLOBAL_DRAM,
217+
self.compilation_settings.dla_global_dram_size,
218+
)
205219

206220
if precision == torch.float16:
207221
builder_config.set_flag(trt.BuilderFlag.FP16)
208222

209223
if precision == torch.int8:
210224
builder_config.set_flag(trt.BuilderFlag.INT8)
211225

212-
if sparse_weights:
226+
if self.compilation_settings.sparse_weights:
213227
builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
214228

215-
if disable_tf32:
229+
if self.compilation_settings.disable_tf32:
216230
builder_config.clear_flag(trt.BuilderFlag.TF32)
217231

232+
if self.compilation_settings.refit:
233+
builder_config.set_flag(trt.BuilderFlag.REFIT)
234+
218235
if strict_type_constraints:
219236
builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
220237

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33
import io
44
from typing import Sequence
55

6+
import tensorrt as trt
67
import torch
78
from torch_tensorrt._Input import Input
89
from torch_tensorrt.dynamo._settings import CompilationSettings
910
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
1011
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
1112
from torch_tensorrt.dynamo.utils import get_torch_inputs
1213

13-
import tensorrt as trt
14-
1514

1615
def convert_module(
1716
module: torch.fx.GraphModule,
@@ -54,18 +53,7 @@ def convert_module(
5453
output_dtypes=output_dtypes,
5554
compilation_settings=settings,
5655
)
57-
interpreter_result = interpreter.run(
58-
workspace_size=settings.workspace_size,
59-
precision=settings.precision,
60-
profiling_verbosity=(
61-
trt.ProfilingVerbosity.VERBOSE
62-
if settings.debug
63-
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
64-
),
65-
max_aux_streams=settings.max_aux_streams,
66-
version_compatible=settings.version_compatible,
67-
optimization_level=settings.optimization_level,
68-
)
56+
interpreter_result = interpreter.run()
6957

7058
if settings.use_python_runtime:
7159
return PythonTorchTensorRTModule(

tests/py/dynamo/conversion/harness.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ def run_test(
5050
interpreter,
5151
rtol,
5252
atol,
53-
precision=torch.float,
5453
check_dtype=True,
5554
):
5655
with torch.no_grad():
@@ -60,7 +59,7 @@ def run_test(
6059

6160
mod.eval()
6261
start = time.perf_counter()
63-
interpreter_result = interpreter.run(precision=precision)
62+
interpreter_result = interpreter.run()
6463
sec = time.perf_counter() - start
6564
_LOGGER.info(f"Interpreter run time(s): {sec}")
6665
trt_mod = PythonTorchTensorRTModule(
@@ -234,7 +233,9 @@ def run_test(
234233

235234
# Previous instance of the interpreter auto-casted 64-bit inputs
236235
# We replicate this behavior here
237-
compilation_settings = CompilationSettings(truncate_long_and_double=True)
236+
compilation_settings = CompilationSettings(
237+
precision=precision, truncate_long_and_double=True
238+
)
238239

239240
interp = TRTInterpreter(
240241
mod,
@@ -248,7 +249,6 @@ def run_test(
248249
interp,
249250
rtol,
250251
atol,
251-
precision,
252252
check_dtype,
253253
)
254254

0 commit comments

Comments
 (0)