diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index d355cefe77..6928347baa 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -96,6 +96,8 @@ def cross_compile_for_windows(
     strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
     immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
+    tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
+    l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -169,6 +171,8 @@ def cross_compile_for_windows(
         strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
         immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
         enable_weight_streaming (bool): Enable weight streaming.
+        tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
+        l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -326,6 +330,8 @@ def cross_compile_for_windows(
         "immutable_weights": immutable_weights,
         "enable_cross_compile_for_windows": True,
         "enable_weight_streaming": enable_weight_streaming,
+        "tiling_optimization_level": tiling_optimization_level,
+        "l2_limit_for_tiling": l2_limit_for_tiling,
     }
 
     # disable the following settings is not supported for cross compilation for windows feature
@@ -413,6 +419,8 @@ def compile(
     strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
     immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
+    tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
+    l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -488,6 +496,8 @@ def compile(
         strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
         immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
         enable_weight_streaming (bool): Enable weight streaming.
+        tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
+        l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -662,6 +672,8 @@ def compile(
         "immutable_weights": immutable_weights,
         "enable_cross_compile_for_windows": False,
         "enable_weight_streaming": enable_weight_streaming,
+        "tiling_optimization_level": tiling_optimization_level,
+        "l2_limit_for_tiling": l2_limit_for_tiling,
     }
 
     settings = CompilationSettings(**compilation_options)
@@ -950,6 +962,8 @@ def convert_exported_program_to_serialized_trt_engine(
     strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
     immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
+    tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
+    l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -1013,6 +1027,8 @@ def convert_exported_program_to_serialized_trt_engine(
         strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
         immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
         enable_weight_streaming (bool): Enable weight streaming.
+        tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
+        l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
@@ -1129,6 +1145,8 @@ def convert_exported_program_to_serialized_trt_engine(
         "strip_engine_weights": strip_engine_weights,
         "immutable_weights": immutable_weights,
         "enable_weight_streaming": enable_weight_streaming,
+        "tiling_optimization_level": tiling_optimization_level,
+        "l2_limit_for_tiling": l2_limit_for_tiling,
     }
 
     settings = CompilationSettings(**compilation_options)
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 18932e6cd0..ba404a4102 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -47,6 +47,8 @@
 ENABLE_WEIGHT_STREAMING = False
 ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
 USE_AOT_JOINT_EXPORT = True
+TILING_OPTIMIZATION_LEVEL = "none"
+L2_LIMIT_FOR_TILING = -1
 
 
 def default_device() -> Device:
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 05fb5ce094..fc23ad76cf 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -20,6 +20,7 @@
     ENGINE_CAPABILITY,
     HARDWARE_COMPATIBLE,
     IMMUTABLE_WEIGHTS,
+    L2_LIMIT_FOR_TILING,
     LAZY_ENGINE_INIT,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
@@ -31,6 +32,7 @@
     REUSE_CACHED_ENGINES,
     SPARSE_WEIGHTS,
     STRIP_ENGINE_WEIGHTS,
+    TILING_OPTIMIZATION_LEVEL,
     TIMING_CACHE_PATH,
     TRUNCATE_DOUBLE,
     USE_AOT_JOINT_EXPORT,
@@ -93,6 +95,8 @@ class CompilationSettings:
         enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
             True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
         use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
+        tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
+        l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -134,6 +138,8 @@ class CompilationSettings:
     enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
     enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
     use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT
+    tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
+    l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
 
 
 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
@@ -149,6 +155,8 @@ class CompilationSettings:
     "strip_engine_weights",  # TODO: @Evan to remove this after implementing caching weight-stripped engines as default?
     "immutable_weights",
     "enable_weight_streaming",
+    "tiling_optimization_level",
+    "l2_limit_for_tiling",
 )
 
 
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 7f26a7c3e6..248e06bc3c 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -329,6 +329,25 @@ def _populate_trt_builder_config(
         if self.compilation_settings.enable_weight_streaming:
             builder_config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)
 
+        if version.parse(trt.__version__) >= version.parse("10.8"):
+            TilingOptimizationLevel = {
+                "none": trt.TilingOptimizationLevel.NONE,
+                "fast": trt.TilingOptimizationLevel.FAST,
+                "moderate": trt.TilingOptimizationLevel.MODERATE,
+                "full": trt.TilingOptimizationLevel.FULL,
+            }
+            assert (
+                self.compilation_settings.tiling_optimization_level
+                in TilingOptimizationLevel
+            ), f"Invalid tiling optimization level: {self.compilation_settings.tiling_optimization_level}. We currently support {TilingOptimizationLevel.keys()}."
+            builder_config.tiling_optimization_level = TilingOptimizationLevel[
+                self.compilation_settings.tiling_optimization_level
+            ]
+
+            builder_config.l2_limit_for_tiling = (
+                self.compilation_settings.l2_limit_for_tiling
+            )
+
         return builder_config
 
     def _create_timing_cache(
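Example usage of the new options (a minimal sketch, not part of the patch above: the toy model, input shape, and the torch_tensorrt.compile call with ir="dynamo" are illustrative assumptions; the tiling settings only take effect when the installed TensorRT is >= 10.8, matching the version gate in _TRTInterpreter.py):

    import torch
    import torch_tensorrt

    # Toy model and input for illustration; any exportable module would do.
    model = (
        torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
            torch.nn.ReLU(),
        )
        .eval()
        .cuda()
    )
    inputs = [torch.randn(1, 3, 224, 224).cuda()]

    # "moderate"/"full" let the builder spend more time searching for a tiling
    # strategy; l2_limit_for_tiling=-1 (the default) places no limit on the
    # targeted L2 cache usage.
    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        tiling_optimization_level="moderate",
        l2_limit_for_tiling=-1,
    )
    trt_model(*inputs)

With an older TensorRT, the version gate in _populate_trt_builder_config skips the tiling-related builder config entirely, so the same call still compiles and the two settings are simply ignored.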