Skip to content

Commit fdc7e3c

Browse files
committed
support tiling optimization
1 parent 7dbd4cb commit fdc7e3c

File tree

5 files changed

+69
-0
lines changed

5 files changed

+69
-0
lines changed

py/torch_tensorrt/dynamo/_compiler.py

+18
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def cross_compile_for_windows(
9696
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
9797
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
9898
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
99+
tiling_optimization_level: Optional[int] = _defaults.TILING_OPTIMIZATION_LEVEL,
100+
l2_limit_for_tiling: Optional[int] = _defaults.L2_LIMIT_FOR_TILING,
99101
**kwargs: Any,
100102
) -> torch.fx.GraphModule:
101103
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -169,6 +171,8 @@ def cross_compile_for_windows(
169171
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
170172
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
171173
enable_weight_streaming (bool): Enable weight streaming.
174+
tiling_optimization_level (Optional[int]): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
175+
l2_limit_for_tiling (Optional[int]): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
172176
**kwargs: Any,
173177
Returns:
174178
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -326,6 +330,8 @@ def cross_compile_for_windows(
326330
"immutable_weights": immutable_weights,
327331
"enable_cross_compile_for_windows": True,
328332
"enable_weight_streaming": enable_weight_streaming,
333+
"tiling_optimization_level": tiling_optimization_level,
334+
"l2_limit_for_tiling": l2_limit_for_tiling,
329335
}
330336

331337
# disable the following settings is not supported for cross compilation for windows feature
@@ -413,6 +419,8 @@ def compile(
413419
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
414420
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
415421
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
422+
tiling_optimization_level: Optional[int] = _defaults.TILING_OPTIMIZATION_LEVEL,
423+
l2_limit_for_tiling: Optional[int] = _defaults.L2_LIMIT_FOR_TILING,
416424
**kwargs: Any,
417425
) -> torch.fx.GraphModule:
418426
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -488,6 +496,8 @@ def compile(
488496
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
489497
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
490498
enable_weight_streaming (bool): Enable weight streaming.
499+
tiling_optimization_level (Optional[int]): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
500+
l2_limit_for_tiling (Optional[int]): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
491501
**kwargs: Any,
492502
Returns:
493503
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -662,6 +672,8 @@ def compile(
662672
"immutable_weights": immutable_weights,
663673
"enable_cross_compile_for_windows": False,
664674
"enable_weight_streaming": enable_weight_streaming,
675+
"tiling_optimization_level": tiling_optimization_level,
676+
"l2_limit_for_tiling": l2_limit_for_tiling,
665677
}
666678

667679
settings = CompilationSettings(**compilation_options)
@@ -950,6 +962,8 @@ def convert_exported_program_to_serialized_trt_engine(
950962
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
951963
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
952964
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
965+
tiling_optimization_level: Optional[int] = _defaults.TILING_OPTIMIZATION_LEVEL,
966+
l2_limit_for_tiling: Optional[int] = _defaults.L2_LIMIT_FOR_TILING,
953967
**kwargs: Any,
954968
) -> bytes:
955969
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -1013,6 +1027,8 @@ def convert_exported_program_to_serialized_trt_engine(
10131027
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
10141028
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
10151029
enable_weight_streaming (bool): Enable weight streaming.
1030+
tiling_optimization_level (Optional[int]): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
1031+
l2_limit_for_tiling (Optional[int]): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
10161032
Returns:
10171033
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
10181034
"""
@@ -1129,6 +1145,8 @@ def convert_exported_program_to_serialized_trt_engine(
11291145
"strip_engine_weights": strip_engine_weights,
11301146
"immutable_weights": immutable_weights,
11311147
"enable_weight_streaming": enable_weight_streaming,
1148+
"tiling_optimization_level": tiling_optimization_level,
1149+
"l2_limit_for_tiling": l2_limit_for_tiling,
11321150
}
11331151

11341152
settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/_defaults.py

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
ENABLE_WEIGHT_STREAMING = False
4848
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
4949
USE_AOT_JOINT_EXPORT = True
50+
TILING_OPTIMIZATION_LEVEL = 0
51+
L2_LIMIT_FOR_TILING = -1
5052

5153

5254
def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

+8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ENGINE_CAPABILITY,
2121
HARDWARE_COMPATIBLE,
2222
IMMUTABLE_WEIGHTS,
23+
L2_LIMIT_FOR_TILING,
2324
LAZY_ENGINE_INIT,
2425
MAX_AUX_STREAMS,
2526
MIN_BLOCK_SIZE,
@@ -31,6 +32,7 @@
3132
REUSE_CACHED_ENGINES,
3233
SPARSE_WEIGHTS,
3334
STRIP_ENGINE_WEIGHTS,
35+
TILING_OPTIMIZATION_LEVEL,
3436
TIMING_CACHE_PATH,
3537
TRUNCATE_DOUBLE,
3638
USE_AOT_JOINT_EXPORT,
@@ -93,6 +95,8 @@ class CompilationSettings:
9395
enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
9496
True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
9597
use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
98+
tiling_optimization_level (Optional[int]): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
99+
l2_limit_for_tiling (Optional[int]): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
96100
"""
97101

98102
enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -134,6 +138,8 @@ class CompilationSettings:
134138
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
135139
enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
136140
use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT
141+
tiling_optimization_level: Optional[int] = TILING_OPTIMIZATION_LEVEL
142+
l2_limit_for_tiling: Optional[int] = L2_LIMIT_FOR_TILING
137143

138144

139145
_SETTINGS_TO_BE_ENGINE_INVARIANT = (
@@ -149,6 +155,8 @@ class CompilationSettings:
149155
"strip_engine_weights", # TODO: @Evan to remove this after implementing caching weight-stripped engines as default?
150156
"immutable_weights",
151157
"enable_weight_streaming",
158+
"tiling_optimization_level",
159+
"l2_limit_for_tiling",
152160
)
153161

154162

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+35
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,41 @@ def _populate_trt_builder_config(
329329
if self.compilation_settings.enable_weight_streaming:
330330
builder_config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)
331331

332+
if version.parse(trt.__version__) >= version.parse("10.8"):
333+
if (
334+
self.compilation_settings.tiling_optimization_level is None
335+
or self.compilation_settings.tiling_optimization_level == 0
336+
):
337+
builder_config.tiling_optimization_level = (
338+
trt.TilingOptimizationLevel.NONE
339+
)
340+
elif self.compilation_settings.tiling_optimization_level == 1:
341+
builder_config.tiling_optimization_level = (
342+
trt.TilingOptimizationLevel.FAST
343+
)
344+
elif self.compilation_settings.tiling_optimization_level == 2:
345+
builder_config.tiling_optimization_level = (
346+
trt.TilingOptimizationLevel.MODERATE
347+
)
348+
elif self.compilation_settings.tiling_optimization_level == 3:
349+
builder_config.tiling_optimization_level = (
350+
trt.TilingOptimizationLevel.FULL
351+
)
352+
else:
353+
raise ValueError(
354+
f"Invalid tiling optimization level: {self.compilation_settings.tiling_optimization_level}. A valid value should be in [0, 1, 2, 3]."
355+
)
356+
357+
if (
358+
self.compilation_settings.l2_limit_for_tiling is None
359+
or self.compilation_settings.l2_limit_for_tiling == -1
360+
):
361+
builder_config.l2_limit_for_tiling = -1
362+
else:
363+
builder_config.l2_limit_for_tiling = (
364+
self.compilation_settings.l2_limit_for_tiling
365+
)
366+
332367
return builder_config
333368

334369
def _create_timing_cache(

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

+6
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ def __init__(
9292
dryrun: bool = _defaults.DRYRUN,
9393
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
9494
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
95+
tiling_optimization_level: Optional[int] = _defaults.TILING_OPTIMIZATION_LEVEL,
96+
l2_limit_for_tiling: Optional[int] = _defaults.L2_LIMIT_FOR_TILING,
9597
**kwargs: Any,
9698
) -> None:
9799
"""
@@ -133,6 +135,8 @@ def __init__(
133135
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
134136
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
135137
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
138+
tiling_optimization_level (Optional[int]): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
139+
l2_limit_for_tiling (Optional[int]): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
136140
**kwargs: Any,
137141
Returns:
138142
MutableTorchTensorRTModule
@@ -193,6 +197,8 @@ def __init__(
193197
"dryrun": dryrun,
194198
"hardware_compatible": hardware_compatible,
195199
"timing_cache_path": timing_cache_path,
200+
"tiling_optimization_level": tiling_optimization_level,
201+
"l2_limit_for_tiling": l2_limit_for_tiling,
196202
}
197203
self.arg_dynamic_shapes: Optional[tuple[Any]] = None
198204
self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None

0 commit comments

Comments
 (0)