diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py
index 7deed3e470..ca16e2ad9b 100644
--- a/py/torch_tensorrt/fx/fx2trt.py
+++ b/py/torch_tensorrt/fx/fx2trt.py
@@ -163,6 +163,7 @@ def run(
         algorithm_selector=None,
         timing_cache=None,
         profiling_verbosity=None,
+        tactic_sources=None,
     ) -> TRTInterpreterResult:
         """
         Build TensorRT engine with some configs.
@@ -245,6 +246,9 @@ def run(
             builder_config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE)
             builder_config.algorithm_selector = algorithm_selector
 
+        if tactic_sources is not None:
+            builder_config.set_tactic_sources(tactic_sources=tactic_sources)
+
         engine = self.builder.build_engine(self.network, builder_config)
         assert engine
 
diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py
index 470f78c407..387b4db841 100644
--- a/py/torch_tensorrt/fx/lower.py
+++ b/py/torch_tensorrt/fx/lower.py
@@ -120,6 +120,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
             profiling_verbosity=trt.ProfilingVerbosity.DETAILED
             if self.lower_setting.verbose_profile
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+            tactic_sources=self.lower_setting.tactic_sources,
         )
 
         # Update timing cache file if needed
diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py
index d3f2cc9a14..c1d02229e3 100644
--- a/py/torch_tensorrt/fx/lower_setting.py
+++ b/py/torch_tensorrt/fx/lower_setting.py
@@ -1,5 +1,5 @@
 import dataclasses as dc
-from typing import List, Optional, Sequence, Set, Type
+from typing import List, Optional, Set, Type
 
 from torch import nn
 from torch.fx.passes.pass_manager import PassManager
@@ -68,6 +68,8 @@ class LowerSetting(LowerSettingBasic):
         opt_profile_replica (int): the number of opt profile set for TensorRT engine,
             this field is only used by explicit batch dim with dynamic shape mode.
         dynamic_batch: enable the dynamic shape in TRT with dim=-1 for the 1st dimension.
+        tactic_sources: tactic sources for TensorRT kernel selection. Defaults to None,
+            meaning all possible tactic sources are enabled.
     """
 
     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -87,3 +89,4 @@ class LowerSetting(LowerSettingBasic):
     preset_lowerer: str = ""
     opt_profile_replica: int = 1
     dynamic_batch: bool = True
+    tactic_sources: Optional[int] = None
diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
index 937737b60d..047ceb3ad2 100644
--- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
+++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -148,7 +148,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:
 
                 # Only acc submodules will be lowered.
                 if not submod_name.startswith(split_result.non_acc_submodule_prefix):
-                    _LOGGER.info("Now lowering submodule", submod_name)
+                    _LOGGER.info(f"Now lowering submodule {submod_name}")
                     lowering_start_time = datetime.datetime.now()
 
                     self.lower_setting.input_specs = generate_input_specs(
@@ -166,8 +166,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:
                         submod_name, lowered_module, submod_inputs
                     )
                     _LOGGER.info(
-                        f"Lowering submodule {submod_name} elapsed time",
-                        datetime.datetime.now() - lowering_start_time,
+                        f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
                     )
 
             return split_result.split_module
@@ -184,7 +183,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:
 
                 # Only acc submodules will be lowered.
                 if not submod_name.startswith(split_result.non_acc_submodule_prefix):
-                    _LOGGER.info("Now lowering submodule", submod_name)
+                    _LOGGER.info(f"Now lowering submodule {submod_name}")
                     lowering_start_time = datetime.datetime.now()
 
                     lowered_module = self._lower_func(
@@ -195,8 +194,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:
                         submod_name, lowered_module, submod_inputs
                     )
                     _LOGGER.info(
-                        f"Lowering submodule {submod_name} elapsed time",
-                        datetime.datetime.now() - lowering_start_time,
+                        f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
                     )
 
             return split_result.split_module
diff --git a/py/torch_tensorrt/fx/tools/trt_splitter.py b/py/torch_tensorrt/fx/tools/trt_splitter.py
index 7fbca8d99a..28279a117d 100644
--- a/py/torch_tensorrt/fx/tools/trt_splitter.py
+++ b/py/torch_tensorrt/fx/tools/trt_splitter.py
@@ -49,7 +49,7 @@ def __init__(self):
         # During split, we'll split out the operators that
         # don't support the batch dim.
         self.use_implicit_batch_dim: bool = True
-        self.exclude_support_node_name: set = set()
+        self.exclude_support_node_name: set = set(self.op_lowering_disallow_list)
 
 
 class TRTSplitter(splitter_base._SplitterBase):
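
For reference, a minimal usage sketch of the new knob. It assumes TensorRT's documented trt.TacticSource bitmask convention (each enum member names a bit position) and the existing Lowerer.create entry point; MyModel and the input shape are illustrative placeholders, not part of this change:

    import tensorrt as trt
    import torch

    from torch_tensorrt.fx.lower import Lowerer
    from torch_tensorrt.fx.lower_setting import LowerSetting

    # Build a mask that restricts kernel selection to cuBLAS and cuDNN.
    # Each trt.TacticSource member is a bit position, so shift and OR them
    # into the integer that reaches builder_config.set_tactic_sources().
    tactic_sources = (1 << int(trt.TacticSource.CUBLAS)) | (
        1 << int(trt.TacticSource.CUDNN)
    )

    # tactic_sources=None (the default) leaves every tactic source enabled.
    lower_setting = LowerSetting(tactic_sources=tactic_sources)
    lowerer = Lowerer.create(lower_setting=lower_setting)
    trt_module = lowerer(MyModel().cuda().eval(), [torch.randn(1, 3, 224, 224).cuda()])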