diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml
index 80a9f15c94..7db6c19636 100644
--- a/.github/workflows/build-test.yml
+++ b/.github/workflows/build-test.yml
@@ -141,6 +141,7 @@ jobs:
            cd tests/py/dynamo
            ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver
            ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
+           ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_dyn_models.py
            popd

   tests-py-torch-compile-be:
diff --git a/cpp/include/torch_tensorrt/torch_tensorrt.h b/cpp/include/torch_tensorrt/torch_tensorrt.h
index 29f860c8b3..adac75d984 100644
--- a/cpp/include/torch_tensorrt/torch_tensorrt.h
+++ b/cpp/include/torch_tensorrt/torch_tensorrt.h
@@ -60,6 +60,8 @@ class DataType {
   enum Value : int8_t {
     /// INT64
     kLong,
+    /// FP64
+    kDouble,
     /// FP32
     kFloat,
     /// FP16
diff --git a/cpp/src/types.cpp b/cpp/src/types.cpp
index 2be7fea338..69b956a162 100644
--- a/cpp/src/types.cpp
+++ b/cpp/src/types.cpp
@@ -97,6 +97,8 @@ at::ScalarType toAtenDataType(DataType value) {
       return at::kInt;
     case DataType::kLong:
       return at::kLong;
+    case DataType::kDouble:
+      return at::kDouble;
     case DataType::kBool:
       return at::kBool;
     case DataType::kFloat:
@@ -119,7 +121,8 @@ nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {
 DataType::DataType(c10::ScalarType t) {
   TORCHTRT_CHECK(
-      t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kLong || t == at::kInt || t == at::kBool,
+      t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kLong || t == at::kDouble || t == at::kInt ||
+          t == at::kBool,
       "Data type is unsupported (" << t << ")");
   switch (t) {
     case at::kHalf:
@@ -134,6 +137,9 @@ DataType::DataType(c10::ScalarType t) {
     case at::kLong:
       value = DataType::kLong;
       break;
+    case at::kDouble:
+      value = DataType::kDouble;
+      break;
     case at::kBool:
       value = DataType::kBool;
       break;
diff --git a/docsrc/index.rst b/docsrc/index.rst
index ded3b99c9d..18fb1185e8 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -42,6 +42,7 @@ User Guide
 * :ref:`getting_started_with_fx`
 * :ref:`ptq`
 * :ref:`runtime`
+* :ref:`dynamic_shapes`
 * :ref:`use_from_pytorch`
 * :ref:`using_dla`

@@ -54,6 +55,7 @@ User Guide
    user_guide/getting_started_with_fx_path
    user_guide/ptq
    user_guide/runtime
+   user_guide/dynamic_shapes
    user_guide/use_from_pytorch
    user_guide/using_dla

diff --git a/docsrc/user_guide/dynamic_shapes.rst b/docsrc/user_guide/dynamic_shapes.rst
new file mode 100644
index 0000000000..28320956c4
--- /dev/null
+++ b/docsrc/user_guide/dynamic_shapes.rst
@@ -0,0 +1,218 @@
.. _dynamic_shapes:

Dynamic shapes with Torch-TensorRT
====================================

By default, you can run a PyTorch model with varied input shapes, and the output shapes are determined eagerly.
However, Torch-TensorRT is an AOT compiler that requires prior information about the input shapes to compile and optimize the model.
For dynamic input shapes, we must provide the ``(min_shape, opt_shape, max_shape)`` arguments so that the model can be optimized for
this range of input shapes. An example usage of static and dynamic shapes is as follows.

NOTE: The following code uses the dynamo IR. In the case of the TorchScript IR, swap out ``ir=dynamo`` for ``ir=ts``; the behavior is exactly the same.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    # Compile with static shapes
    inputs = torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.float32)
    # or compile with dynamic shapes
    inputs = torch_tensorrt.Input(min_shape=[1, 3, 224, 224],
                                  opt_shape=[4, 3, 224, 224],
                                  max_shape=[8, 3, 224, 224],
                                  dtype=torch.float32)
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
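With the dynamic profile above, the same compiled module should serve any batch size between ``min_shape`` and ``max_shape`` without recompilation. A minimal sketch (assuming the ``trt_gm`` produced by the dynamic-shape branch of the previous snippet):

.. code-block:: python

    # Hypothetical usage of the dynamically compiled module from above.
    # Any batch size between min_shape[0]=1 and max_shape[0]=8 is valid.
    x_bs1 = torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()
    x_bs8 = torch.randn((8, 3, 224, 224), dtype=torch.float32).cuda()

    out_bs1 = trt_gm(x_bs1)  # runs without recompilation
    out_bs8 = trt_gm(x_bs8)  # runs without recompilation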
Under the hood
--------------

There are two phases of compilation when we use the ``torch_tensorrt.compile`` API with ``ir=dynamo`` (default).

- aten_tracer.trace (which uses torch.export to trace the graph with the given inputs)

In the tracing phase, we use torch.export along with constraints. For
dynamic shaped inputs, the valid range can be provided to the tracing via constraints. Please
refer to this `docstring `_
for detailed information on how to set constraints. In short, we create new inputs for
torch.export tracing and provide constraints on the min and max values (provided by the user) that a particular dimension can take.
Please take a look at the ``aten_tracer.py`` file to understand how this works under the hood.

- dynamo.compile (which compiles a torch.fx.GraphModule object using TensorRT)

In the conversion to TensorRT, we use the user-provided dynamic shape inputs.
We perform shape analysis using dummy inputs (across the min, opt and max shapes) and store the
intermediate output shapes, which can be used when the graph has a mix of PyTorch
and TensorRT submodules.

Custom Constraints
------------------

Given an input ``x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype)``,
Torch-TensorRT automatically sets the constraints during ``torch.export`` tracing as follows

.. code-block:: python

    for dim in constraint_dims:
        if min_shape[dim] > 1:
            constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim))
        if max_shape[dim] > 1:
            constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim])

Sometimes we might need to set additional constraints, and Torchdynamo errors out if we don't specify them.
For example, when compiling a BERT model there are two inputs, and a constraint has to be set relating the sequence length dimension of these two inputs.

.. code-block:: python

    constraints.append(dynamic_dim(trace_inputs[0], 0) == dynamic_dim(trace_inputs[1], 0))


If you have to provide any custom constraints to your model, the overall workflow for model compilation using ``ir=dynamo`` would involve a few steps.

.. code-block:: python

    import unittest.mock

    import torch
    import torch_tensorrt
    from torch._export import dynamic_dim, export
    from torch_tensorrt.dynamo.lowering import get_decompositions

    # Assume the model has two integer inputs
    model = MyModel()
    torch_input_1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).cuda()
    torch_input_2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).cuda()

    dynamic_inputs = [torch_tensorrt.Input(min_shape=[1, 14],
                                           opt_shape=[4, 14],
                                           max_shape=[8, 14],
                                           dtype=torch.int32),
                      torch_tensorrt.Input(min_shape=[1, 14],
                                           opt_shape=[4, 14],
                                           max_shape=[8, 14],
                                           dtype=torch.int32)]

    experimental_decompositions = False  # set True to enable experimental decompositions
    compile_spec = {}  # any additional torch_tensorrt.dynamo.compile settings

    # Export the model with additional constraints
    constraints = []
    # The following constraints are automatically added by Torch-TensorRT in the
    # general case when you call torch_tensorrt.compile directly on MyModel()
    constraints.append(dynamic_dim(torch_input_1, 0) <= 8)
    constraints.append(dynamic_dim(torch_input_2, 0) <= 8)
    # This is an additional constraint as instructed by Torchdynamo
    constraints.append(dynamic_dim(torch_input_1, 0) == dynamic_dim(torch_input_2, 0))
    with unittest.mock.patch(
        "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
    ):
        graph_module = export(
            model, (torch_input_1, torch_input_2), constraints=constraints
        ).module()

    # Use the dynamo.compile API
    trt_mod = torch_tensorrt.dynamo.compile(graph_module, inputs=dynamic_inputs, **compile_spec)
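Once exported and compiled this way, the module can be exercised anywhere inside the constrained range. A quick, hypothetical sanity check (assuming ``trt_mod`` from the snippet above, and remembering that the equality constraint ties the two batch dimensions together):

.. code-block:: python

    # Both inputs must share the same batch dimension, per the equality
    # constraint registered above.
    check_input_1 = torch.randint(0, 5, (4, 14), dtype=torch.int32).cuda()
    check_input_2 = torch.randint(0, 5, (4, 14), dtype=torch.int32).cuda()
    outputs = trt_mod(check_input_1, check_input_2)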
Limitations
-----------

If there are operations in the graph that use the dynamic dimension of the input, PyTorch
introduces ``torch.ops.aten.sym_size.int`` ops in the graph. Currently, we cannot handle these operators, and
compilation results in undefined behavior. We plan to add support for these operators and implement
robust support for shape tensors in the next release. Here is an example of the limitation described above:

.. code-block:: python

    import torch
    import torch_tensorrt

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.avgpool(x)
            out = torch.flatten(x, 1)
            return out

    model = MyModule().eval().cuda()
    # Compile with dynamic shapes
    inputs = torch_tensorrt.Input(min_shape=(1, 512, 1, 1),
                                  opt_shape=(4, 512, 1, 1),
                                  max_shape=(8, 512, 1, 1),
                                  dtype=torch.float32)
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)


The traced graph of ``MyModule()`` looks as follows

.. code-block:: python

    Post export graph: graph():
        %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
        %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%arg0_1, [-1, -2], True), kwargs = {})
        %sym_size : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%arg0_1, 0), kwargs = {})
        %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mean, [%sym_size, 512]), kwargs = {})
        return (view,)


Here the ``%sym_size`` node captures the dynamic batch dimension and uses it in the ``aten.view`` layer. This requires shape tensor support,
which will be part of our next release.

Workaround (BERT static compilation example)
--------------------------------------------

If you encounter the issues mentioned in the **Limitations** section,
you can compile the model in static mode with the maximum input size that will be provided. For smaller inputs,
you can pad them accordingly. This is only a workaround until we address the limitations.

.. code-block:: python

    import torch
    import torch_tensorrt
    from transformers import BertModel
    from transformers.utils.fx import symbolic_trace as transformers_trace

    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
    compile_spec = {}  # any additional compilation settings

    # Input sequence length is 20.
    input1 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")
    input2 = torch.randint(0, 5, (1, 20), dtype=torch.int32).to("cuda")

    model = transformers_trace(model, input_names=["input_ids", "attention_mask"]).eval().cuda()
    trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
    model_outputs = model(input1, input2)

    # If you have a sequence of length 14, pad 6 zero tokens and run inference
    # or recompile for a sequence length of 14.
    input1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    input2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    trt_mod = torch_tensorrt.compile(model, inputs=[input1, input2], **compile_spec)
    model_outputs = model(input1, input2)
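As a concrete sketch of the padding route (a hypothetical illustration, assuming the ``trt_mod`` compiled for sequence length 20 earlier in the snippet):

.. code-block:: python

    import torch.nn.functional as F

    # Pad a length-14 sequence with 6 zero tokens up to the compiled length of 20,
    # instead of recompiling for the shorter sequence.
    short_input1 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    short_input2 = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda")
    padded_input1 = F.pad(short_input1, (0, 6), value=0)  # shape: (1, 20)
    padded_input2 = F.pad(short_input2, (0, 6), value=0)  # shape: (1, 20)
    model_outputs = trt_mod(padded_input1, padded_input2)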
Dynamic shapes with ir=torch_compile
------------------------------------

``torch_tensorrt.compile(model, inputs, ir="torch_compile")`` returns a torch.compile boxed function with the backend
configured to TensorRT. With ``ir=torch_compile``, users have to recompile for different input shapes.
In the future, we plan to explore the option of compiling with dynamic shapes in the first execution of the model.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()
    trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs)
    # Compilation happens when you call the model
    trt_gm(inputs)

    # Recompilation happens with a modified batch size
    inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32).cuda()
    trt_gm = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs_bs2)
    trt_gm(inputs_bs2)

diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py
index 3995eca1dd..6e43a23903 100644
--- a/py/torch_tensorrt/_Input.py
+++ b/py/torch_tensorrt/_Input.py
@@ -46,6 +46,7 @@ class _ShapeMode(Enum):
     low_tensor_domain_incl: float = 0.0
     high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
     torch_dtype: torch.dtype = torch.float32
+    torch_tensor: torch.Tensor = None

     def __init__(self, *args: Any, **kwargs: Any) -> None:
         """__init__ Method for torch_tensorrt.Input
@@ -171,6 +172,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

             self.tensor_domain = Input._parse_tensor_domain(domain)

+        if "torch_tensor" in kwargs:
+            self.torch_tensor = kwargs["torch_tensor"]
+        else:
+            if self.shape_mode == Input._ShapeMode.DYNAMIC:
+                self.torch_tensor = self.example_tensor("opt_shape")
+            else:
+                self.torch_tensor = self.example_tensor()
+
     def __str__(self) -> str:
         if self.shape_mode == Input._ShapeMode.STATIC:
             return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format(
@@ -220,6 +229,8 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:
             return _enums.dtype.half
         elif dtype == torch.float:
             return _enums.dtype.float
+        elif dtype == torch.float64:
+            return _enums.dtype.double
         elif dtype == torch.bool:
             return _enums.dtype.bool
         else:
@@ -249,6 +260,8 @@ def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
             return torch.float
         elif dtype == _enums.dtype.bool:
             return torch.bool
+        elif dtype == _enums.dtype.double:
+            return torch.float64
         else:
             # Default torch_dtype used in FX path
             return torch.float32
@@ -354,7 +367,7 @@ def from_tensor(
                 )
                 else torch.channels_last
             )
-            return 
cls(shape=t.shape, dtype=t.dtype, format=frmt) + return cls(shape=t.shape, dtype=t.dtype, format=frmt, torch_tensor=t) @classmethod def from_tensors( diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 08e7f7d424..67bf6d523e 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -214,19 +214,20 @@ def compile( ) return compiled_fx_module elif target_ir == _IRType.dynamo: + # Prepare torch and torchtrt inputs import collections.abc - from torch_tensorrt import Device - from torch_tensorrt.dynamo.utils import prepare_inputs, to_torch_device + from torch_tensorrt.dynamo.utils import prepare_inputs - if not isinstance(inputs, collections.abc.Sequence): - inputs = [inputs] - device = kwargs.get("device", Device._current_device()) - torchtrt_inputs, torch_inputs = prepare_inputs(inputs, to_torch_device(device)) - module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs) + if not isinstance(input_list, collections.abc.Sequence): + input_list = [input_list] + + # Export the module + torchtrt_inputs = prepare_inputs(input_list) + module = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs) compiled_aten_module: torch.fx.GraphModule = dynamo_compile( module, - inputs=input_list, + inputs=torchtrt_inputs, enabled_precisions=enabled_precisions_set, **kwargs, ) diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp index ac2dffb4b8..4794a679eb 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp +++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp @@ -18,6 +18,8 @@ std::string to_str(DataType value) { return "Float"; case DataType::kLong: return "Long"; + case DataType::kDouble: + return "Double"; default: return "Unknown data type"; } @@ -33,6 +35,8 @@ nvinfer1::DataType toTRTDataType(DataType value) { return nvinfer1::DataType::kINT32; case DataType::kLong: return nvinfer1::DataType::kINT32; + case DataType::kDouble: + return nvinfer1::DataType::kFLOAT; case DataType::kBool: return nvinfer1::DataType::kBOOL; case DataType::kFloat: @@ -58,6 +62,8 @@ at::ScalarType toAtenDataType(DataType value) { return at::kBool; case DataType::kFloat: return at::kFloat; + case DataType::kDouble: + return at::kDouble; case DataType::kUnknown: return at::kFloat; default: diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.h b/py/torch_tensorrt/csrc/tensorrt_classes.h index 28321b0571..9bdd00b7e0 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.h +++ b/py/torch_tensorrt/csrc/tensorrt_classes.h @@ -27,7 +27,7 @@ namespace pyapi { return static_cast(field_name); \ } -enum class DataType : int8_t { kLong, kFloat, kHalf, kChar, kInt32, kBool, kUnknown }; +enum class DataType : int8_t { kLong, kDouble, kFloat, kHalf, kChar, kInt32, kBool, kUnknown }; std::string to_str(DataType value); nvinfer1::DataType toTRTDataType(DataType value); at::ScalarType toAtenDataType(DataType value); diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp index b3880b335a..33c7e27398 100644 --- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp +++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp @@ -246,6 +246,8 @@ PYBIND11_MODULE(_C, m) { .value("int32", DataType::kInt32, "32 bit integer number") .value("long", DataType::kLong, "64 bit integer number") .value("int64", DataType::kLong, "64 bit integer number") + .value("double", DataType::kDouble, "64 bit floating point number") + .value("float64", DataType::kDouble, "64 bit floating point number") .value("bool", 
DataType::kBool, "Boolean value")
      .value("unknown", DataType::kUnknown, "Unknown data type")
      .export_values();
diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py
index be2f2efd9c..da346635a2 100644
--- a/py/torch_tensorrt/dynamo/aten_tracer.py
+++ b/py/torch_tensorrt/dynamo/aten_tracer.py
@@ -2,16 +2,32 @@

 import logging
 import unittest.mock
-from typing import Any, Tuple
+from typing import Any, List, Tuple

 import torch
-from torch._export import export
-from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
-from torch_tensorrt.dynamo.utils import set_log_level
+from torch._export import dynamic_dim, export
+from torch_tensorrt._Input import Input
+from torch_tensorrt.dynamo._defaults import default_device
+from torch_tensorrt.dynamo.lowering import get_decompositions
+from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device

 logger = logging.getLogger(__name__)


+def get_random_tensor(
+    shape: List[Any], dtype: torch.dtype, device: torch.device
+) -> torch.Tensor:
+    if dtype == torch.int32 or dtype == torch.int64:
+        return torch.randint(2, 10, shape, dtype=dtype, device=device)
+    elif dtype in (torch.float64, torch.float32, torch.float16):
+        return torch.randn(shape, dtype=dtype, device=device)
+    else:
+        logger.critical(
+            "Invalid dtype detected in creating input tensors for tracing the graph."
+        )
+        raise ValueError(f"Unsupported dtype {dtype} for trace-input generation")
+
+
 def trace(
     model: torch.nn.Module | torch.fx.GraphModule,
     inputs: Tuple[Any, ...],
@@ -21,13 +37,52 @@
     if "debug" in kwargs and kwargs["debug"]:
         set_log_level(logger.parent, logging.DEBUG)

+    # Determine the dynamic dimensions and set up constraints on input dimensions as dictated by TensorRT
+    # Torch dynamo does not allow 0/1 values for dynamic dimensions
+    # of inputs during tracing. 
Hence we create new inputs for export + device = to_torch_device(kwargs.get("device", default_device())) + torch_inputs = get_torch_inputs(inputs, device) + trace_inputs = [] + constraints = [] + for idx, input in enumerate(inputs): + if input.shape_mode == Input._ShapeMode.DYNAMIC: + min_shape = input.shape["min_shape"] + opt_shape = input.shape["opt_shape"] + max_shape = input.shape["max_shape"] + assert len(min_shape) == len(opt_shape) == len(max_shape) + + constraint_dims = [] + new_shape = [] + for dim in range(len(min_shape)): + if min_shape[dim] == opt_shape[dim] == max_shape[dim]: + new_shape.append(torch_inputs[idx].shape[dim]) + else: + constraint_dims.append(dim) + if torch_inputs[idx].shape[dim] == 1: + new_shape.append(torch_inputs[idx].shape[dim] + 1) + else: + new_shape.append(torch_inputs[idx].shape[dim]) + + trace_input = get_random_tensor(new_shape, torch_inputs[idx].dtype, device) + + for dim in constraint_dims: + if min_shape[dim] > 1: + constraints.append(min_shape[dim] <= dynamic_dim(trace_input, dim)) + if max_shape[dim] > 1: + constraints.append(dynamic_dim(trace_input, dim) <= max_shape[dim]) + trace_inputs.append(trace_input) + else: + trace_inputs.append(torch_inputs[idx]) + experimental_decompositions = kwargs.get( "enable_experimental_decompositions", False ) with unittest.mock.patch( "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions) ): - graph_module = export(model, tuple(inputs)).module() - graph_module = apply_lowering_passes(graph_module, inputs) + graph_module = export( + model, tuple(trace_inputs), constraints=constraints + ).module() + logger.debug("Post export graph: " + str(graph_module.graph)) return graph_module diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 7b98079564..b30da1ffb8 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -16,7 +16,11 @@ repair_input_aliasing, ) from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions -from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs, set_log_level +from torch_tensorrt.dynamo.utils import ( + parse_dynamo_kwargs, + prepare_inputs, + set_log_level, +) logger = logging.getLogger(__name__) @@ -89,9 +93,10 @@ def _pretraced_backend( gm = apply_lowering_passes(gm, sample_inputs) + torchtrt_inputs = prepare_inputs(sample_inputs) trt_compiled = compile_module( gm, - sample_inputs, + torchtrt_inputs, settings=settings, ) return trt_compiled diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py index d512fd1bf2..0ef52edd43 100644 --- a/py/torch_tensorrt/dynamo/compile.py +++ b/py/torch_tensorrt/dynamo/compile.py @@ -1,6 +1,5 @@ from __future__ import annotations -import collections.abc import logging from typing import Any, List, Optional, Sequence, Set, Tuple, Union @@ -10,6 +9,7 @@ from torch_tensorrt._enums import ( # TODO: Should probabably be the TRT EngineCapability Enum EngineCapability, ) +from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import CompilationSettings, partitioning from torch_tensorrt.dynamo._defaults import ( DEBUG, @@ -31,8 +31,9 @@ convert_module, repair_long_or_double_inputs, ) +from torch_tensorrt.dynamo.lowering import apply_lowering_passes from torch_tensorrt.dynamo.utils import ( - prepare_inputs, + get_torch_inputs, set_log_level, to_torch_device, to_torch_tensorrt_device, @@ -75,6 +76,10 @@ def compile( if debug: set_log_level(logger.parent, 
logging.DEBUG) + # Apply lowering on the graph module + torch_inputs = get_torch_inputs(inputs, device) + gm = apply_lowering_passes(gm, torch_inputs) + enabled_precisions = set(enabled_precisions) logger.warning( @@ -87,13 +92,8 @@ def compile( "require_full_compilation}" ) - if not isinstance(inputs, collections.abc.Sequence): - inputs = [inputs] - device = to_torch_tensorrt_device(device) - _, torch_inputs = prepare_inputs(inputs, to_torch_device(device)) - if ( torch.float16 in enabled_precisions or torch_tensorrt.dtype.half in enabled_precisions @@ -134,12 +134,12 @@ def compile( settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) - return compile_module(gm, torch_inputs, settings) + return compile_module(gm, inputs, settings) def compile_module( gm: torch.fx.GraphModule, - sample_inputs: Sequence[torch.Tensor], + sample_inputs: Sequence[Input], settings: CompilationSettings = CompilationSettings(), ) -> torch.fx.GraphModule: """Compile a traced FX module @@ -153,6 +153,7 @@ def compile_module( Returns: Compiled FX GraphModule """ + # Check the number of supported operations in the graph num_supported_ops, total_ops = partitioning.get_graph_converter_support( gm, settings.debug, settings.torch_executed_ops @@ -203,7 +204,6 @@ def compile_module( # Store TRT replicas of Torch subgraphs trt_modules = {} - # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those for name, _ in partitioned_module.named_children(): @@ -213,19 +213,30 @@ def compile_module( submodule = getattr(partitioned_module, name) - logger.debug( - "Submodule name: " + str(name) + " Graph: \n" + str(submodule.graph) - ) - # Get submodule inputs + # Get the submodule inputs for min, opt, max shapes of the graph inputs submodule_inputs = partitioning.get_submod_inputs( - partitioned_module, submodule, sample_inputs + partitioned_module, + submodule, + sample_inputs, + to_torch_device(settings.device), + ) + + logger.debug( + "Submodule name: %s\n Input shapes: %s\n %s", + str(name), + [input.shape for input in submodule_inputs], + str(submodule.graph), ) assert submodule_inputs is not None # Handle long/double inputs if requested by the user if settings.truncate_long_and_double: submodule_inputs = repair_long_or_double_inputs( - partitioned_module, submodule, submodule_inputs, name + partitioned_module, + submodule, + submodule_inputs, + to_torch_device(settings.device), + name, ) # Create TRT Module from submodule diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 206636a637..61f3f6b6f5 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -285,6 +285,7 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor: self.optimization_profiles[0].set_shape( target, min_shape, opt_shape, max_shape ) + assert len(min_shape) == len(opt_shape) == len(max_shape) for i in range(len(min_shape)): if min_shape[i] == opt_shape[i] == max_shape[i]: diff --git a/py/torch_tensorrt/dynamo/conversion/conversion.py b/py/torch_tensorrt/dynamo/conversion/conversion.py index 5555686e77..77a66c7c6d 100644 --- a/py/torch_tensorrt/dynamo/conversion/conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/conversion.py @@ -9,11 +9,12 @@ from torch_tensorrt.dynamo import CompilationSettings from torch_tensorrt.dynamo.conversion import TRTInterpreter from 
torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule +from torch_tensorrt.dynamo.utils import get_torch_inputs def convert_module( module: torch.fx.GraphModule, - inputs: Sequence[torch.Tensor], + inputs: Sequence[Input], settings: CompilationSettings = CompilationSettings(), name: str = "", ) -> PythonTorchTensorRTModule | TorchTensorRTModule: @@ -28,15 +29,17 @@ def convert_module( """ # Specify module output data types to ensure TRT output types agree with # that of the equivalent Torch module - module_outputs = module(*inputs) + torch_inputs = get_torch_inputs(inputs, settings.device) + module_outputs = module(*torch_inputs) if not isinstance(module_outputs, (list, tuple)): module_outputs = [module_outputs] output_dtypes = [output.dtype for output in module_outputs] + interpreter = TRTInterpreter( module, - Input.from_tensors(inputs, disable_memory_format_check=True), + inputs, logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING), output_dtypes=output_dtypes, compilation_settings=settings, diff --git a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py b/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py index c6c10f475a..9390bc3bde 100644 --- a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py +++ b/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py @@ -4,6 +4,8 @@ import torch from torch.fx.node import _get_qualified_name +from torch_tensorrt._Input import Input +from torch_tensorrt.dynamo.utils import get_torch_inputs def _extract_downstream_get_nodes( @@ -157,9 +159,10 @@ def _repair_64bit_input( def repair_long_or_double_inputs( parent_graph: torch.fx.GraphModule, submodule: torch.fx.GraphModule, - submodule_inputs: Sequence[torch.Tensor], + submodule_inputs: Sequence[Input], + device: torch.device, submodule_name: Optional[str] = None, -) -> Sequence[torch.Tensor]: +) -> Sequence[Input]: """Fixes all Long/Double type inputs to a TRT-accelerated subgraph In-Place modifies the provided graph @@ -175,12 +178,13 @@ def repair_long_or_double_inputs( Returns: New submodule inputs, updated accordingly with long/double truncation """ + submodule_torch_inputs = get_torch_inputs(submodule_inputs, device) num_submodule_inputs = len(submodule_inputs) repaired_outputs_once = False # For each input to the TRT subgraph, check if its type is long/double for position in range(num_submodule_inputs): - param = submodule_inputs[position] + param = submodule_torch_inputs[position] # If the data type of the input is long/double, insert necessary # casts to replace the operation @@ -188,7 +192,7 @@ def repair_long_or_double_inputs( # Ensure outputs are only repaired once per submodule to avoid # unnecessary ops showing up in the graph if not repaired_outputs_once: - submodule_outputs = submodule(*submodule_inputs) + submodule_outputs = submodule(*submodule_torch_inputs) _repair_64bit_input( parent_graph, @@ -202,12 +206,17 @@ def repair_long_or_double_inputs( # Repair submodule inputs in accordance with inserted casts dtype_32bit = torch.int32 if (param.dtype == torch.int64) else torch.float32 - submodule_inputs = ( - list(submodule_inputs[:position]) + submodule_torch_inputs = ( + list(submodule_torch_inputs[:position]) + [ param.to(dtype_32bit), ] - + list(submodule_inputs[position + 1 :]) + + list(submodule_torch_inputs[position + 1 :]) ) + # Set the 32bit inputs and their types to the submodule Inputs + for idx in range(len(submodule_inputs)): + submodule_inputs[idx].torch_tensor = 
submodule_torch_inputs[idx] + submodule_inputs[idx].torch_dtype = submodule_torch_inputs[idx].dtype + return submodule_inputs diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py index 8c36668d00..14c068260f 100644 --- a/py/torch_tensorrt/dynamo/partitioning/common.py +++ b/py/torch_tensorrt/dynamo/partitioning/common.py @@ -1,9 +1,14 @@ +import logging from typing import Any, Optional, Sequence, Set, Tuple import torch from torch.fx.node import _get_qualified_name +from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._defaults import DEBUG from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY +from torch_tensorrt.dynamo.utils import get_torch_inputs, input_is_dynamic + +logger = logging.getLogger(__name__) DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = { _get_qualified_name(to_replace.new_operator) @@ -14,7 +19,8 @@ def get_submod_inputs( mod: torch.fx.GraphModule, submod: torch.fx.GraphModule, - inputs: Sequence[torch.Tensor], + inputs: Sequence[Input], + device: torch.device, ) -> Optional[Sequence[torch.Tensor]]: """Helper function to get inputs to a Torch submodule @@ -25,17 +31,63 @@ def get_submod_inputs( Returns: Sequence of Tensors representing inputs to child module """ - acc_inputs = None + acc_inputs: Any = None def get_input(self: Any, inputs: Sequence[torch.Tensor]) -> None: nonlocal acc_inputs acc_inputs = inputs return + # Register a hook to capture submodule input handle = submod.register_forward_pre_hook(get_input) - mod(*inputs) - handle.remove() - return acc_inputs + # Iterate over min, opt, max shapes for dynamic inputs + inputs_map = {} + + if input_is_dynamic(inputs): + for mode in ["min_shape", "opt_shape", "max_shape"]: + torch_inputs = get_torch_inputs(inputs, device, mode) + mod(*torch_inputs) + inputs_map[mode] = acc_inputs + handle.remove() + else: + torch_inputs = get_torch_inputs(inputs, device) + mod(*torch_inputs) + handle.remove() + assert isinstance(acc_inputs, tuple) + return [ + Input(shape=acc_input.shape, dtype=acc_input.dtype) + for acc_input in acc_inputs + ] + + num_submodule_inputs = ( + len(inputs_map["min_shape"]) if inputs_map["min_shape"] else 0 + ) + submodule_inputs = [] + for idx in range(num_submodule_inputs): + if not isinstance(inputs_map["min_shape"][idx], torch.Tensor): + input_val = torch.tensor(inputs_map["opt_shape"][idx], dtype=torch.int32) + logger.warning( + "Detected a zero-dimensional input. This might be a shape tensor input which is not currently supported. This might result in undefined behavior" + ) + submodule_inputs.append( + Input( + shape=[1], + torch_tensor=input_val, + dtype=input_val.dtype, + ) + ) + else: + submodule_inputs.append( + Input( + min_shape=inputs_map["min_shape"][idx].shape, + opt_shape=inputs_map["opt_shape"][idx].shape, + max_shape=inputs_map["max_shape"][idx].shape, + torch_tensor=inputs_map["opt_shape"][idx], + dtype=inputs_map["opt_shape"][idx].dtype, + ) + ) + + return submodule_inputs def get_graph_converter_support( diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 28149c8fde..97046ba421 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -63,6 +63,35 @@ def cosine_similarity(gt_tensor: torch.Tensor, pred_tensor: torch.Tensor) -> flo return res +def input_is_dynamic(inputs: Sequence[Union[Input, torch.Tensor]]) -> bool: + """ + Return true if the provided inputs are `torch_tensorrt.Input` objects and have dynamic shapes. 
+    """
+    return not any(isinstance(input, torch.Tensor) for input in inputs) and any(
+        input.shape_mode == Input._ShapeMode.DYNAMIC for input in inputs
+    )
+
+
+def get_torch_inputs(
+    inputs: Sequence[Input], device: Union[Device, torch.device, str], mode: str = ""
+) -> Sequence[torch.Tensor]:
+    """
+    Return the torch_tensor from the Input object. If mode is set, this implies
+    the user is using dynamic shaped inputs; return the corresponding input for
+    the requested mode.
+    """
+    device = to_torch_device(device)
+    if mode:
+        return [
+            input.example_tensor(mode).to(device)
+            for input in inputs
+            if isinstance(input, Input)
+        ]
+    return [
+        input.torch_tensor.to(device) for input in inputs if isinstance(input, Input)
+    ]
+
+
 def set_log_level(parent_logger: Any, level: Any) -> None:
     """
     Sets the log level to the user provided level.
@@ -75,49 +104,37 @@

 def prepare_inputs(
     inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
-    device: torch.device = torch.device("cuda"),
 ) -> Any:
     if isinstance(inputs, Input):
-        if isinstance(inputs.shape, dict):
-            return inputs, inputs.example_tensor(
-                optimization_profile_field="opt_shape"
-            ).to(device)
-        else:
-            return inputs, inputs.example_tensor().to(device)
+        return inputs

     elif isinstance(inputs, torch.Tensor):
-        return Input.from_tensor(inputs), inputs
+        return Input.from_tensor(inputs)

     elif isinstance(inputs, list):
         torchtrt_input_list = []
-        torch_input_list = []
         for input_obj in inputs:
-            torchtrt_input, torch_input = prepare_inputs(input_obj)
+            torchtrt_input = prepare_inputs(input_obj)
             torchtrt_input_list.append(torchtrt_input)
-            torch_input_list.append(torch_input)

-        return torchtrt_input_list, torch_input_list
+        return torchtrt_input_list

     elif isinstance(inputs, tuple):
         torchtrt_inputs_tup = []
-        torch_inputs_tup = []
         for input_obj in inputs:
-            torchtrt_input, torch_input = prepare_inputs(input_obj)
+            torchtrt_input = prepare_inputs(input_obj)
             torchtrt_inputs_tup.append(torchtrt_input)
-            torch_inputs_tup.append(torch_input)

-        return tuple(torchtrt_inputs_tup), tuple(torch_inputs_tup)
+        return tuple(torchtrt_inputs_tup)

     elif isinstance(inputs, dict):
         torchtrt_inputs_dict: Dict[Any, Any] = dict()
-        torch_inputs_dict: Dict[Any, Any] = dict()

         for key, input_obj in inputs.items():
-            torchtrt_input, torch_input = prepare_inputs(input_obj)
+            torchtrt_input = prepare_inputs(input_obj)
             torchtrt_inputs_dict[key] = torchtrt_input
-            torch_inputs_dict[key] = torch_input

-        return torchtrt_inputs_dict, torch_inputs_dict
+        return torchtrt_inputs_dict

     else:
         raise ValueError(
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
new file mode 100644
index 0000000000..057a95879d
--- /dev/null
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -0,0 +1,113 @@
+import unittest
+
+import pytest
+import timm
+import torch
+import torch_tensorrt as torchtrt
+from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
+
+assertions = unittest.TestCase()
+
+
+@pytest.mark.unit
+def test_base_dynamic(ir):
+    """
+    Tests a fully convertible model with dynamic shapes
+    """
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            out = self.conv(x)
+            out = self.relu(out)
+            return out
+
+    model = MyModule().eval().cuda()
+    input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(4, 3, 224, 224),
+                max_shape=(8, 3, 224, 224),
+                dtype=torch.float32,
+            )
+        ],
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "ir": ir,
+        "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 1,
+    }
+
+    trt_mod = torchtrt.compile(model, **compile_spec)
+    cos_sim = cosine_similarity(model(input), trt_mod(input)[0])
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_base_dynamic model TRT outputs don't match with the PyTorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+    # Clean up model env
+    torch._dynamo.reset()
+
+    with torch.no_grad():
+        torch.cuda.empty_cache()
+
+
+@pytest.mark.unit
+def test_base_dynamic_fallback(ir):
+    """
+    Tests a partially convertible model (abs falls back to PyTorch) with dynamic shapes
+    """
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            out = self.conv(x)
+            out = torch.abs(out)
+            out = self.relu(out)
+            return out
+
+    model = MyModule().eval().cuda()
+    input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(4, 3, 224, 224),
+                max_shape=(8, 3, 224, 224),
+                dtype=torch.float32,
+            )
+        ],
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "ir": ir,
+        "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "torch_executed_ops": "torch.ops.aten.abs.default",
+        "min_block_size": 1,
+    }
+
+    trt_mod = torchtrt.compile(model, **compile_spec)
+    cos_sim = cosine_similarity(model(input), trt_mod(input)[0])
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_base_dynamic_fallback model TRT outputs don't match with the PyTorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+    # Clean up model env
+    torch._dynamo.reset()
+
+    with torch.no_grad():
+        torch.cuda.empty_cache()
diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
index d084ce68f6..fd7b40592a 100644
--- a/tests/py/dynamo/models/test_models_export.py
+++ b/tests/py/dynamo/models/test_models_export.py
@@ -154,6 +154,59 @@ def test_bert_base_uncased(ir):
     torch._dynamo.reset()


+@pytest.mark.unit
+def test_bert_base_uncased_traced(ir):
+    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
+    input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
+    input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
+    model = (
+        transformers_trace(model, input_names=["input_ids", "attention_mask"])
+        .eval()
+        .cuda()
+    )
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                input.shape,
+                dtype=input.dtype,
+                format=torch.contiguous_format,
+            ),
+            torchtrt.Input(
+                input.shape,
+                dtype=input.dtype,
+                format=torch.contiguous_format,
+            ),
+        ],
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "truncate_long_and_double": True,
+        "ir": ir,
+        "min_block_size": 10,
+        "torch_executed_ops": {"torch.ops.aten.gelu.default"},
+    }
+    trt_mod = torchtrt.compile(model, **compile_spec)
+    model_outputs = model(input, input2)
+    trt_model_outputs = trt_mod(input, input2)
+    assertions.assertTrue(
+        len(model_outputs) == len(trt_model_outputs),
+        msg=f"Number of outputs for BERT model compilation is different with PyTorch {len(model_outputs)} and TensorRT {len(trt_model_outputs)}. Please check the compilation.",
+    )
+    for index, key in enumerate(model_outputs):
+        out, trt_out = model_outputs[key], trt_model_outputs[index]
+        cos_sim = cosine_similarity(out, trt_out)
+        assertions.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"HF BERT base-uncased TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() + + @pytest.mark.unit def test_resnet18_half(ir): model = models.resnet18(pretrained=True).eval().to("cuda").half() diff --git a/tests/py/dynamo/runtime/test_compiler_utils.py b/tests/py/dynamo/runtime/test_compiler_utils.py index afd21c3079..02b0d63523 100644 --- a/tests/py/dynamo/runtime/test_compiler_utils.py +++ b/tests/py/dynamo/runtime/test_compiler_utils.py @@ -60,23 +60,17 @@ def test_cast_str_device(self): class TestPrepareInputs(unittest.TestCase): def test_prepare_single_tensor_input(self): inputs = [torch.ones((4, 4))] - prepared_inputs_trt, prepared_inputs_torch = prepare_inputs(inputs) + prepared_inputs_trt = prepare_inputs(inputs) self.assertTrue( same_output_format(inputs, prepared_inputs_trt, enforce_tensor_type=False) ) - self.assertTrue( - same_output_format(inputs, prepared_inputs_torch, enforce_tensor_type=False) - ) def test_prepare_trt_input(self): inputs = [torch_tensorrt.Input(shape=(4, 3), dtype=torch.float)] - prepared_inputs_trt, prepared_inputs_torch = prepare_inputs(inputs) + prepared_inputs_trt = prepare_inputs(inputs) self.assertTrue( same_output_format(inputs, prepared_inputs_trt, enforce_tensor_type=False) ) - self.assertTrue( - same_output_format(inputs, prepared_inputs_torch, enforce_tensor_type=False) - ) def test_prepare_mixed_type_compound_tensor_input(self): inputs = { @@ -89,13 +83,10 @@ def test_prepare_mixed_type_compound_tensor_input(self): (torch.rand((5, 1)), torch_tensorrt.Input(shape=(2, 3))), ), } - prepared_inputs_trt, prepared_inputs_torch = prepare_inputs(inputs) + prepared_inputs_trt = prepare_inputs(inputs) self.assertTrue( same_output_format(inputs, prepared_inputs_trt, enforce_tensor_type=False) ) - self.assertTrue( - same_output_format(inputs, prepared_inputs_torch, enforce_tensor_type=False) - ) if __name__ == "__main__":
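
For context on what the ``kDouble``/``float64`` plumbing in this change enables, a minimal, hypothetical usage sketch follows (``MyModel`` is a placeholder module; TensorRT has no native FP64, so truncation to FP32 must be enabled):

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # placeholder: any module taking one FP64 tensor
    # float64 Inputs are now accepted; enable truncation so the compiler
    # casts these inputs down to float32 for the TensorRT engine.
    inputs = [torch_tensorrt.Input(shape=(1, 3, 224, 224), dtype=torch.float64)]
    trt_mod = torch_tensorrt.compile(
        model, ir="dynamo", inputs=inputs, truncate_long_and_double=True
    )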