diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py
index 0846dec144..3743b263db 100644
--- a/py/torch_tensorrt/dynamo/backend/__init__.py
+++ b/py/torch_tensorrt/dynamo/backend/__init__.py
@@ -16,6 +16,7 @@
     DEBUG,
     MAX_WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
+    PASS_THROUGH_BUILD_FAILURES,
 )
 
 
@@ -46,11 +47,14 @@ def compile(
     torch_executed_modules=[],
     **kwargs,
 ):
+    if debug:
+        logger.setLevel(logging.DEBUG)
 
     logger.warn(
         "The Dynamo backend is an experimental feature, for which only the "
         + "following arguments are supported: "
-        + "{enabled_precisions, debug, workspace_size, min_block_size, torch_executed_ops}"
+        + "{enabled_precisions, debug, workspace_size, min_block_size, "
+        + "torch_executed_ops, pass_through_build_failures}"
     )
 
     if not isinstance(inputs, collections.abc.Sequence):
@@ -104,6 +108,7 @@ def create_backend(
     workspace_size: int = MAX_WORKSPACE_SIZE,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Sequence[str] = set(),
+    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
     **kwargs,
 ):
     """Create torch.compile backend given specified arguments
@@ -116,12 +121,16 @@
     Returns:
         Backend for torch.compile
     """
+    if debug:
+        logger.setLevel(logging.DEBUG)
+
     settings = CompilationSettings(
         debug=debug,
         precision=precision,
         workspace_size=workspace_size,
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
+        pass_through_build_failures=pass_through_build_failures,
     )
 
     return partial(
diff --git a/py/torch_tensorrt/dynamo/backend/_defaults.py b/py/torch_tensorrt/dynamo/backend/_defaults.py
index b1ee62dfa3..fe7b5f6b4f 100644
--- a/py/torch_tensorrt/dynamo/backend/_defaults.py
+++ b/py/torch_tensorrt/dynamo/backend/_defaults.py
@@ -5,3 +5,4 @@
 DEBUG = False
 MAX_WORKSPACE_SIZE = 20 << 30
 MIN_BLOCK_SIZE = 5
+PASS_THROUGH_BUILD_FAILURES = False
diff --git a/py/torch_tensorrt/dynamo/backend/_settings.py b/py/torch_tensorrt/dynamo/backend/_settings.py
index 8c1a807343..df3212f54a 100644
--- a/py/torch_tensorrt/dynamo/backend/_settings.py
+++ b/py/torch_tensorrt/dynamo/backend/_settings.py
@@ -7,6 +7,7 @@
     DEBUG,
     MAX_WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
+    PASS_THROUGH_BUILD_FAILURES,
 )
 
 
@@ -17,3 +18,4 @@ class CompilationSettings:
     workspace_size: int = MAX_WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Sequence[str] = field(default_factory=set)
+    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 962cbe8eba..8f6408492a 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -1,6 +1,6 @@
+import logging
 from typing import Sequence
 import torch
-import traceback
 from functools import partial
 import torch._dynamo as td
 
@@ -19,6 +19,9 @@
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
 
 
+logger = logging.getLogger(__name__)
+
+
 @td.register_backend(name="torch_tensorrt")
 @fake_tensor_unsupported
 def torch_tensorrt_backend(
@@ -52,6 +55,7 @@ def aot_torch_tensorrt_aten_backend(
     )
 
 
+@fake_tensor_unsupported
 def _pretraced_backend(
     gm: torch.fx.GraphModule,
     sample_inputs: Sequence[torch.Tensor],
@@ -74,12 +78,22 @@
         )
         return trt_compiled
     except:
-        traceback.print_exc()
-        print(
+        logger.error(
             "FX2TRT conversion failed on the subgraph. See trace above. "
-            + "Returning GraphModule forward instead."
+            + "Returning GraphModule forward instead.",
+            exc_info=True,
         )
-        return gm.forward
+
+        if not settings.pass_through_build_failures:
+            return gm.forward
+        else:
+            raise AssertionError(
+                "Halting compilation on build failure since "
+                + "pass_through_build_failures was specified as True. "
+                + "To return the default Torch implementation and avoid "
+                + "halting compilation on engine build failures, "
+                + "specify pass_through_build_failures=False."
+            )
 
 
 def _compile_module(
@@ -120,9 +134,7 @@
         trt_mod = convert_module(
             submodule,
             submodule_inputs,
-            debug=settings.debug,
-            workspace_size=settings.workspace_size,
-            precision=settings.precision,
+            settings=settings,
         )
 
         # Replace FX Module with TRT Module
diff --git a/py/torch_tensorrt/dynamo/backend/conversion.py b/py/torch_tensorrt/dynamo/backend/conversion.py
index 4f495dad4b..1644dea547 100644
--- a/py/torch_tensorrt/dynamo/backend/conversion.py
+++ b/py/torch_tensorrt/dynamo/backend/conversion.py
@@ -2,11 +2,11 @@
 import torch
 from torch_tensorrt.fx.trt_module import TRTModule
 from torch_tensorrt import TRTModuleNext
+from torch_tensorrt.dynamo.backend._settings import CompilationSettings
 from torch_tensorrt.fx.fx2trt import (
     InputTensorSpec,
     TRTInterpreter,
 )
-from torch_tensorrt.fx.utils import LowerPrecision
 
 import tensorrt as trt
 
@@ -14,17 +14,13 @@
 def convert_module(
     module: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
-    debug: bool = False,
-    workspace_size: int = 20 << 30,
-    precision: LowerPrecision = LowerPrecision.FP32,
+    settings: CompilationSettings = CompilationSettings(),
 ) -> Union[TRTModuleNext, TRTModule]:
     """Convert an FX module to a TRT module
     Args:
         module: FX GraphModule to convert
         inputs: Sequence of Tensors representing inputs to the module
-        debug: Whether to print out verbose debugging information
-        workspace_size: Maximum workspace TRT is allowed to use for the module
-        precision: Model Layer precision
+        settings: Compilation settings
     Returns:
         TRTModule or TRTModuleNext
     """
@@ -32,15 +28,15 @@
         module,
         InputTensorSpec.from_tensors(inputs),
         explicit_batch_dimension=True,
-        logger_level=(trt.Logger.VERBOSE if debug else trt.Logger.WARNING),
+        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
     )
 
     r = interp.run(
-        max_workspace_size=workspace_size,
-        lower_precision=precision,
+        max_workspace_size=settings.workspace_size,
+        lower_precision=settings.precision,
         profiling_verbosity=(
             trt.ProfilingVerbosity.VERBOSE
-            if debug
+            if settings.debug
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
         ),
     )
diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py
index b4d1b18db9..5cd83d768c 100644
--- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py
+++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py
@@ -136,15 +136,18 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
                 f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
             )
 
-        logger.debug("\nSupported Nodes:")
+        # Reformat support messages for debugger to print node overview as a single string
+        supported_nodes_str = "\nSupported Nodes:\n"
         for node_name in self.supported_operators:
-            logger.debug("-", node_name)
+            supported_nodes_str += f"- {node_name}\n"
+
+        logger.debug(supported_nodes_str)
 
         if len(self.unsupported_operators) != 0:
-            logger.debug("\nUnsupported or Excluded Nodes:")
+            unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n"
             for node_name in self.unsupported_operators:
-                logger.debug("-", node_name)
-            logger.debug("\n")
+                unsupported_nodes_str += f"- {node_name}\n"
+            logger.debug(unsupported_nodes_str)
         else:
             logger.debug("\nAll Nodes Supported\n")
 
diff --git a/py/torch_tensorrt/dynamo/backend/test/utils.py b/py/torch_tensorrt/dynamo/backend/test/utils.py
index d59b710faf..48f6443e32 100644
--- a/py/torch_tensorrt/dynamo/backend/test/utils.py
+++ b/py/torch_tensorrt/dynamo/backend/test/utils.py
@@ -124,7 +124,7 @@ def lower_graph_testing(
     torch_executed_ops: Sequence[str] = set(),
     testing_partitioning: bool = False,
 ):
-    """Helper function to assist with graph lowering for testing of Dynamo torch_compile
+    """Helper function to assist with graph lowering for testing of Dynamo compile
 
     Args:
         fx_graph: Graph to lower
diff --git a/py/torch_tensorrt/dynamo/common_utils/__init__.py b/py/torch_tensorrt/dynamo/common_utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/py/torch_tensorrt/dynamo/test/utils.py b/py/torch_tensorrt/dynamo/common_utils/test_utils.py
similarity index 100%
rename from py/torch_tensorrt/dynamo/test/utils.py
rename to py/torch_tensorrt/dynamo/common_utils/test_utils.py
diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
index b86817df56..9f2ecf1432 100644
--- a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
+++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
@@ -7,7 +7,10 @@
 
 from transformers import BertModel
 
-from utils import COSINE_THRESHOLD, cosine_similarity
+from torch_tensorrt.dynamo.common_utils.test_utils import (
+    COSINE_THRESHOLD,
+    cosine_similarity,
+)
 
 
 @pytest.mark.unit
@@ -30,7 +33,7 @@ def test_resnet18(ir):
     cos_sim = cosine_similarity(model(input), trt_mod(input))
     assert (
         cos_sim > COSINE_THRESHOLD,
-        f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
     # Clean up model env
@@ -163,7 +166,7 @@
     cos_sim = cosine_similarity(model(input), trt_mod(input))
     assert (
         cos_sim > COSINE_THRESHOLD,
-        f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
    )
 
     # Clean up model env
diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index c86f2bd228..e8efc30ddf 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -358,7 +358,7 @@ def aten_ops_cat(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     kwargs_new = {
         "tensors": args[0],
-        "dim": args[1],
+        "dim": args[1] if len(args) >= 2 else 0,
     }
     return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)
 
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py
index cfeb235af3..bd11747b15 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py
@@ -53,6 +53,41 @@ def forward(self, x, y):
             expected_ops={torch.ops.aten.cat.default},
         )
 
+    def test_cat_no_dim(self):
+        class Cat(nn.Module):
+            def forward(self, x, y, z):
+                return torch.cat((x, y, z))
+
+        inputs = [torch.randn(2, 1, 3), torch.randn(1, 1, 3), torch.randn(3, 1, 3)]
+        self.run_test(
+            Cat(),
+            inputs,
+            expected_ops={torch.ops.aten.cat.default},
+        )
+
+    def test_cat_dynamic_shape_no_dim(self):
+        class Cat(nn.Module):
+            def forward(self, x, y):
+                return torch.cat((x, y))
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 16, 3),
+                dtype=torch.float32,
+                shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))],
+            ),
+            InputTensorSpec(
+                shape=(-1, 16, 3),
+                dtype=torch.float32,
+                shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Cat(),
+            input_specs,
+            expected_ops={torch.ops.aten.cat.default},
+        )
+
 
 if __name__ == "__main__":
     run_tests()
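
Example usage (a minimal sketch, not part of the patch): the new pass_through_build_failures option is consumed by create_backend, which returns a backend for torch.compile. Only create_backend, debug, and pass_through_build_failures come from this change; the toy model and input shape below are placeholders.

    import torch
    from torch_tensorrt.dynamo.backend import create_backend

    # Placeholder model; any CUDA-resident FP32 module would do.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, 3, padding=1),
        torch.nn.ReLU(),
    ).eval().cuda()
    inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

    # With pass_through_build_failures=True, a failed TRT engine build now raises
    # an AssertionError instead of silently falling back to gm.forward.
    backend = create_backend(debug=True, pass_through_build_failures=True)
    optimized = torch.compile(model, backend=backend)
    optimized(*inputs)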