Skip to content

Commit 656bb9e

Browse files
authored
feat: support tiling optimization as of TRT 10.8 (#3444)
1 parent e722cc5 commit 656bb9e

File tree

4 files changed

+47
-0
lines changed

4 files changed

+47
-0
lines changed

py/torch_tensorrt/dynamo/_compiler.py

+18
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def cross_compile_for_windows(
9696
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
9797
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
9898
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
99+
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
100+
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
99101
**kwargs: Any,
100102
) -> torch.fx.GraphModule:
101103
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -169,6 +171,8 @@ def cross_compile_for_windows(
169171
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
170172
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
171173
enable_weight_streaming (bool): Enable weight streaming.
174+
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
175+
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
172176
**kwargs: Any,
173177
Returns:
174178
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -326,6 +330,8 @@ def cross_compile_for_windows(
326330
"immutable_weights": immutable_weights,
327331
"enable_cross_compile_for_windows": True,
328332
"enable_weight_streaming": enable_weight_streaming,
333+
"tiling_optimization_level": tiling_optimization_level,
334+
"l2_limit_for_tiling": l2_limit_for_tiling,
329335
}
330336

331337
# disable the following settings is not supported for cross compilation for windows feature
@@ -413,6 +419,8 @@ def compile(
413419
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
414420
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
415421
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
422+
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
423+
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
416424
**kwargs: Any,
417425
) -> torch.fx.GraphModule:
418426
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -488,6 +496,8 @@ def compile(
488496
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
489497
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
490498
enable_weight_streaming (bool): Enable weight streaming.
499+
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
500+
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
491501
**kwargs: Any,
492502
Returns:
493503
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -662,6 +672,8 @@ def compile(
662672
"immutable_weights": immutable_weights,
663673
"enable_cross_compile_for_windows": False,
664674
"enable_weight_streaming": enable_weight_streaming,
675+
"tiling_optimization_level": tiling_optimization_level,
676+
"l2_limit_for_tiling": l2_limit_for_tiling,
665677
}
666678

667679
settings = CompilationSettings(**compilation_options)
@@ -950,6 +962,8 @@ def convert_exported_program_to_serialized_trt_engine(
950962
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
951963
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
952964
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
965+
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
966+
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
953967
**kwargs: Any,
954968
) -> bytes:
955969
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -1013,6 +1027,8 @@ def convert_exported_program_to_serialized_trt_engine(
10131027
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
10141028
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
10151029
enable_weight_streaming (bool): Enable weight streaming.
1030+
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
1031+
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
10161032
Returns:
10171033
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
10181034
"""
@@ -1129,6 +1145,8 @@ def convert_exported_program_to_serialized_trt_engine(
11291145
"strip_engine_weights": strip_engine_weights,
11301146
"immutable_weights": immutable_weights,
11311147
"enable_weight_streaming": enable_weight_streaming,
1148+
"tiling_optimization_level": tiling_optimization_level,
1149+
"l2_limit_for_tiling": l2_limit_for_tiling,
11321150
}
11331151

11341152
settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/_defaults.py

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
ENABLE_WEIGHT_STREAMING = False
4848
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
4949
USE_AOT_JOINT_EXPORT = True
50+
TILING_OPTIMIZATION_LEVEL = "none"
51+
L2_LIMIT_FOR_TILING = -1
5052

5153

5254
def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

+8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ENGINE_CAPABILITY,
2121
HARDWARE_COMPATIBLE,
2222
IMMUTABLE_WEIGHTS,
23+
L2_LIMIT_FOR_TILING,
2324
LAZY_ENGINE_INIT,
2425
MAX_AUX_STREAMS,
2526
MIN_BLOCK_SIZE,
@@ -31,6 +32,7 @@
3132
REUSE_CACHED_ENGINES,
3233
SPARSE_WEIGHTS,
3334
STRIP_ENGINE_WEIGHTS,
35+
TILING_OPTIMIZATION_LEVEL,
3436
TIMING_CACHE_PATH,
3537
TRUNCATE_DOUBLE,
3638
USE_AOT_JOINT_EXPORT,
@@ -93,6 +95,8 @@ class CompilationSettings:
9395
enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
9496
True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
9597
use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
98+
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
99+
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
96100
"""
97101

98102
enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -134,6 +138,8 @@ class CompilationSettings:
134138
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
135139
enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
136140
use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT
141+
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
142+
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
137143

138144

139145
_SETTINGS_TO_BE_ENGINE_INVARIANT = (
@@ -149,6 +155,8 @@ class CompilationSettings:
149155
"strip_engine_weights", # TODO: @Evan to remove this after implementing caching weight-stripped engines as default?
150156
"immutable_weights",
151157
"enable_weight_streaming",
158+
"tiling_optimization_level",
159+
"l2_limit_for_tiling",
152160
)
153161

154162

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+19
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,25 @@ def _populate_trt_builder_config(
329329
if self.compilation_settings.enable_weight_streaming:
330330
builder_config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)
331331

332+
if version.parse(trt.__version__) >= version.parse("10.8"):
333+
TilingOptimizationLevel = {
334+
"none": trt.TilingOptimizationLevel.NONE,
335+
"fast": trt.TilingOptimizationLevel.FAST,
336+
"moderate": trt.TilingOptimizationLevel.MODERATE,
337+
"full": trt.TilingOptimizationLevel.FULL,
338+
}
339+
assert (
340+
self.compilation_settings.tiling_optimization_level
341+
in TilingOptimizationLevel
342+
), f"Invalid tiling optimization level: {self.compilation_settings.tiling_optimization_level}. We currently support {TilingOptimizationLevel.keys()}."
343+
builder_config.tiling_optimization_level = TilingOptimizationLevel[
344+
self.compilation_settings.tiling_optimization_level
345+
]
346+
347+
builder_config.l2_limit_for_tiling = (
348+
self.compilation_settings.l2_limit_for_tiling
349+
)
350+
332351
return builder_config
333352

334353
def _create_timing_cache(

0 commit comments

Comments
 (0)