Skip to content

feat: support tiling optimization as of TRT 10.8 #3444

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def cross_compile_for_windows(
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
Expand Down Expand Up @@ -169,6 +171,8 @@ def cross_compile_for_windows(
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for a better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
Expand Down Expand Up @@ -326,6 +330,8 @@ def cross_compile_for_windows(
"immutable_weights": immutable_weights,
"enable_cross_compile_for_windows": True,
"enable_weight_streaming": enable_weight_streaming,
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
}

# disable the following settings is not supported for cross compilation for windows feature
Expand Down Expand Up @@ -413,6 +419,8 @@ def compile(
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
Expand Down Expand Up @@ -488,6 +496,8 @@ def compile(
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for a better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
Expand Down Expand Up @@ -662,6 +672,8 @@ def compile(
"immutable_weights": immutable_weights,
"enable_cross_compile_for_windows": False,
"enable_weight_streaming": enable_weight_streaming,
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
}

settings = CompilationSettings(**compilation_options)
Expand Down Expand Up @@ -950,6 +962,8 @@ def convert_exported_program_to_serialized_trt_engine(
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
Expand Down Expand Up @@ -1013,6 +1027,8 @@ def convert_exported_program_to_serialized_trt_engine(
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for a better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
Expand Down Expand Up @@ -1129,6 +1145,8 @@ def convert_exported_program_to_serialized_trt_engine(
"strip_engine_weights": strip_engine_weights,
"immutable_weights": immutable_weights,
"enable_weight_streaming": enable_weight_streaming,
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
}

settings = CompilationSettings(**compilation_options)
Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
ENABLE_WEIGHT_STREAMING = False
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
USE_AOT_JOINT_EXPORT = True
TILING_OPTIMIZATION_LEVEL = "none"
L2_LIMIT_FOR_TILING = -1


def default_device() -> Device:
Expand Down
8 changes: 8 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ENGINE_CAPABILITY,
HARDWARE_COMPATIBLE,
IMMUTABLE_WEIGHTS,
L2_LIMIT_FOR_TILING,
LAZY_ENGINE_INIT,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
Expand All @@ -31,6 +32,7 @@
REUSE_CACHED_ENGINES,
SPARSE_WEIGHTS,
STRIP_ENGINE_WEIGHTS,
TILING_OPTIMIZATION_LEVEL,
TIMING_CACHE_PATH,
TRUNCATE_DOUBLE,
USE_AOT_JOINT_EXPORT,
Expand Down Expand Up @@ -93,6 +95,8 @@ class CompilationSettings:
enable_cross_compile_for_windows (bool): By default this is False, which means TensorRT engines can only be executed on the same platform where they were built.
True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for a better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
Expand Down Expand Up @@ -134,6 +138,8 @@ class CompilationSettings:
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
Expand All @@ -149,6 +155,8 @@ class CompilationSettings:
"strip_engine_weights", # TODO: @Evan to remove this after implementing caching weight-stripped engines as default?
"immutable_weights",
"enable_weight_streaming",
"tiling_optimization_level",
"l2_limit_for_tiling",
)


Expand Down
19 changes: 19 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,25 @@ def _populate_trt_builder_config(
if self.compilation_settings.enable_weight_streaming:
builder_config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)

if version.parse(trt.__version__) >= version.parse("10.8"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can just drop 10.7 instead of having this piecemeal support

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we do it for some settings but not others, so we need to decide if we want versioned builder config or not

Copy link
Collaborator

@narendasan narendasan Mar 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My default stance is no, but if it's not too much work (outside of 2.7 scope) then we might want to, in which case this can stay

TilingOptimizationLevel = {
"none": trt.TilingOptimizationLevel.NONE,
"fast": trt.TilingOptimizationLevel.FAST,
"moderate": trt.TilingOptimizationLevel.MODERATE,
"full": trt.TilingOptimizationLevel.FULL,
}
assert (
self.compilation_settings.tiling_optimization_level
in TilingOptimizationLevel
), f"Invalid tiling optimization level: {self.compilation_settings.tiling_optimization_level}. We currently support {TilingOptimizationLevel.keys()}."
builder_config.tiling_optimization_level = TilingOptimizationLevel[
self.compilation_settings.tiling_optimization_level
]

builder_config.l2_limit_for_tiling = (
self.compilation_settings.l2_limit_for_tiling
)
Comment on lines +347 to +349
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you want to be really safe (when we remove version guarding), you can check if self.compilation_settings.get("l2_limit_for_tiling", -1) != -1 or something.


return builder_config

def _create_timing_cache(
Expand Down