diff --git a/py/torch_tensorrt/fx/__init__.py b/py/torch_tensorrt/fx/__init__.py
index fa0afc33d1..aeae62d86d 100644
--- a/py/torch_tensorrt/fx/__init__.py
+++ b/py/torch_tensorrt/fx/__init__.py
@@ -6,5 +6,6 @@
     tensorrt_converter,
 )
 from .fx2trt import TRTInterpreter, TRTInterpreterResult  # noqa
-from .input_tensor_spec import InputTensorSpec  # noqa
+from .input_tensor_spec import generate_input_specs, InputTensorSpec  # noqa
+from .lower_setting import LowerSetting  # noqa
 from .trt_module import TRTModule  # noqa
diff --git a/py/torch_tensorrt/fx/input_tensor_spec.py b/py/torch_tensorrt/fx/input_tensor_spec.py
index 79c572d9b6..9429a7661f 100644
--- a/py/torch_tensorrt/fx/input_tensor_spec.py
+++ b/py/torch_tensorrt/fx/input_tensor_spec.py
@@ -1,4 +1,4 @@
-from typing import Iterable, List, NamedTuple, Sequence, Tuple
+from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple

 import torch

@@ -6,6 +6,61 @@
 from .utils import get_dynamic_dims


+def generate_input_specs(
+    inputs, lower_setting, additional_inputs=None, fixed_shape=False
+):
+    # The AIT lower setting doesn't have an explicit_batch_dimension field,
+    # so we just return None.
+    if not hasattr(lower_setting, "explicit_batch_dimension"):
+        return None
+
+    if not lower_setting.explicit_batch_dimension or fixed_shape:
+        return InputTensorSpec.from_tensors(inputs)
+
+    # If we don't have additional inputs, we assume the first dimension
+    # is the dynamic batch dimension. Otherwise, we use the additional
+    # inputs to determine the batch dimension.
+    if additional_inputs is None:
+        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            inputs,
+            (
+                0,
+                lower_setting.max_batch_size,
+                lower_setting.max_batch_size,
+            ),
+            lower_setting.opt_profile_replica,
+        )
+    else:
+        batch_dims = []
+
+        for i, j in zip(inputs, additional_inputs):
+            found_batch_dim = False
+
+            for idx, values in enumerate(zip(i.shape, j.shape)):
+                if values[0] != values[1]:
+                    assert (
+                        found_batch_dim is False
+                    ), f"We've already found a batch dim, {i.shape}, {j.shape}."
+                    batch_dims.append(idx)
+                    found_batch_dim = True
+
+            if not found_batch_dim:
+                raise RuntimeError(
+                    f"Failed to find batch dimension because shapes are the same, {i.shape}"
+                )
+
+        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            inputs,
+            (
+                0,
+                lower_setting.max_batch_size,
+                lower_setting.max_batch_size,
+            ),
+            lower_setting.opt_profile_replica,
+            batch_dims,
+        )
+
+
 class InputTensorSpec(NamedTuple):
     """
     This class contains the information of a input tensor.
@@ -70,6 +125,7 @@ def from_tensors_with_dynamic_batch_size(
         tensors: Sequence[torch.Tensor],
         batch_size_range: Tuple[int, int, int],
         opt_profile_replica: int = 1,
+        batch_dims: Optional[List[int]] = None,
     ) -> List["InputTensorSpec"]:
         """
         Produce a list of InputTenosrSpec named tuples which would contain
@@ -83,20 +139,30 @@
             the smallest batch size allowed. The second integer indiceates
             the batch size that we'll optimize for. The third integer indicates
             the largest batch size allowed.
+            opt_profile_replica (int): If dynamic shape is enabled, each execution
+                context requires a different optimization profile. This arg determines
+                how many optimization profile replicas we want to produce.
+            batch_dims (Optional[List[int]]): The batch dim might not be the leading dim,
+                so this arg lets the user specify the batch dims explicitly. By default,
+                dim 0 is treated as the batch dim.

         Returns:
             A list of InputTensorSpec named tuples with dynamic ranges.
         """
+        if batch_dims is None:
+            batch_dims = [0] * len(tensors)
+
         input_specs = []
-        batch_size = tensors[0].size(0)
+        batch_size = tensors[0].size(batch_dims[0])

         for i, tensor in enumerate(tensors):
+            batch_dim = batch_dims[i]
             assert batch_size == tensor.size(
-                0
+                batch_dim
             ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
             shape = list(tensor.shape)
-            shape[0] = -1
-            shape_ranges: List[ShapeRange] = [tuple(tuple([bs] + shape[1:]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
+            shape[batch_dim] = -1
+            shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
             input_specs.append(
                 cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
             )
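The densest new line above is the `shape_ranges` construction in `from_tensors_with_dynamic_batch_size`. A standalone sketch of what it computes for a non-leading batch dim (all values here are hypothetical, not part of the patch):

```python
# Sketch: how a single shape range is built when the batch dim is not dim 0.
shape = [8, 4, 16]              # e.g. tensor.shape with the batch at dim 1
batch_dim = 1
batch_size_range = (1, 32, 64)  # (min, opt, max)

shape[batch_dim] = -1           # mark the batch dim as dynamic
shape_range = tuple(
    tuple(shape[:batch_dim] + [bs] + shape[batch_dim + 1:])
    for bs in batch_size_range
)
print(shape_range)  # ((8, 1, 16), (8, 32, 16), (8, 64, 16))
```

The real code then replicates this range `opt_profile_replica` times, one per TensorRT optimization profile.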
""" + if batch_dims is None: + batch_dims = [0] * len(tensors) + input_specs = [] - batch_size = tensors[0].size(0) + batch_size = tensors[0].size(batch_dims[0]) for i, tensor in enumerate(tensors): + batch_dim = batch_dims[i] assert batch_size == tensor.size( - 0 + batch_dim ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}." shape = list(tensor.shape) - shape[0] = -1 - shape_ranges: List[ShapeRange] = [tuple(tuple([bs] + shape[1:]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item] + shape[batch_dim] = -1 + shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item] input_specs.append( cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges) ) diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 9f8ec7865c..82791faf12 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -1,6 +1,6 @@ import dataclasses as dc import logging -from typing import Any, Callable, Sequence +from typing import Any, Callable, Optional, Sequence # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt @@ -10,15 +10,9 @@ from torch.fx.passes.splitter_base import SplitResult from .fx2trt import TRTInterpreter, TRTInterpreterResult -from .input_tensor_spec import InputTensorSpec from .lower_setting import LowerSetting from .passes.lower_pass_manager_builder import LowerPassManagerBuilder -from .passes.pass_utils import ( - chain_passes, - decorate_method, - PassFunc, - validate_inference, -) +from .passes.pass_utils import decorate_method, PassFunc, validate_inference from .tools.timing_cache_utils import TimingCacheManager from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting @@ -91,25 +85,8 @@ def create(cls, lower_setting): return LowerTrtInterpreter(lower_setting, timing_cache_manager) def __call__(self, mod, input, split_name) -> TRTInterpreterResult: - input_specs_val = ( - self.lower_setting.input_specs - if self.lower_setting.input_specs - else ( - InputTensorSpec.from_tensors_with_dynamic_batch_size( - input, - ( - 0, - self.lower_setting.max_batch_size, - self.lower_setting.max_batch_size, - ), - self.lower_setting.opt_profile_replica, - ) - if self.lower_setting.explicit_batch_dimension - and self.lower_setting.dynamic_batch - else InputTensorSpec.from_tensors(input) - ) - ) - logger.info(f"{split_name=} {input_specs_val=}") + assert self.lower_setting.input_specs, "Can't find input specs for lowering!" 
+ logger.info(f"{split_name=} {self.lower_setting.input_specs=}") # Prepare algorithm selector and timing_cache for TRTInterpreter algo_selector = None @@ -125,7 +102,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult: interpreter = TRTInterpreter( mod, - input_specs=input_specs_val, + input_specs=self.lower_setting.input_specs, explicit_batch_dimension=self.lower_setting.explicit_batch_dimension, explicit_precision=self.lower_setting.explicit_precision, logger_level=trt.Logger.VERBOSE @@ -242,6 +219,7 @@ def __call__( self, module: nn.Module, inputs: Input, + additional_inputs: Optional[Input] = None, ) -> nn.Module: module.eval() @@ -254,7 +232,9 @@ def __call__( x.half() if x is not None and x.dtype == torch.float32 else x for x in inputs ) - pm = self.lower_pass_manager_builder.build_lower_pipeline(inputs) + pm = self.lower_pass_manager_builder.build_lower_pipeline( + inputs, additional_inputs + ) lower_result = pm(module) diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index 0f8e2233a2..ef39fd3bf7 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -1,11 +1,13 @@ from functools import partial, wraps -from typing import Any, Callable, Sequence +from typing import Any, Callable, Optional, Sequence import torch from torch import nn from torch.fx.passes.pass_manager import inplace_wrapper, PassManager from torch.fx.passes.shape_prop import ShapeProp -from torch.fx.passes.splitter_base import SplitResult +from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult + +from ..input_tensor_spec import generate_input_specs from ..lower_setting import LowerSetting from ..observer import Observer @@ -120,6 +122,19 @@ def _split_pass(self) -> PassManager: def _lower_pass(self) -> PassManager: def lower_func(split_result: SplitResult) -> nn.Module: + if ( + hasattr(self.lower_setting, "explicit_batch_dimension") + and self.lower_setting.explicit_batch_dimension + and self._additional_input + ): + additional_submodule_inputs = generate_inputs_for_submodules( + split_result.split_module, + self._additional_input, + list(split_result.submodule_inputs.keys()), + ) + else: + additional_submodule_inputs = None + for submod_name, submod_inputs in split_result.submodule_inputs.items(): submod = getattr(split_result.split_module, submod_name) @@ -127,6 +142,13 @@ def lower_func(split_result: SplitResult) -> nn.Module: # Only acc submodules will be lowered. 
diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py
index 2018433599..b075b744bc 100644
--- a/py/torch_tensorrt/fx/passes/pass_utils.py
+++ b/py/torch_tensorrt/fx/passes/pass_utils.py
@@ -41,10 +41,13 @@ def _validate_inference(pass_: PassFunc) -> PassFunc:

         @wraps(pass_)
         def pass_with_validation(
-            module: fx.GraphModule, input: Input
+            module: fx.GraphModule,
+            input: Input,
+            *args,
+            **kwargs,
         ) -> fx.GraphModule:
             res0 = module(*input)
-            processed_module = pass_(module, input)
+            processed_module = pass_(module, input, *args, **kwargs)
             res1 = processed_module(*input)

             tensor_res_0 = _collect_tensors(res0)
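This wrapper change matters because lowering passes may now take arguments beyond `(module, input)`, and the validation wrapper must forward them untouched. A runnable sketch of the forwarding behavior, using a simplified stand-in for `_validate_inference` and a hypothetical pass:

```python
from functools import wraps

import torch
from torch import fx, nn


def validate_pass(pass_):
    # Simplified stand-in for pass_utils._validate_inference, showing only
    # the *args/**kwargs forwarding this hunk adds.
    @wraps(pass_)
    def pass_with_validation(module, input, *args, **kwargs):
        res0 = module(*input)
        processed = pass_(module, input, *args, **kwargs)  # extras forwarded
        res1 = processed(*input)
        assert torch.allclose(res0, res1)
        return processed
    return pass_with_validation


@validate_pass
def my_pass(module, input, verbose=False):  # hypothetical extra kwarg
    if verbose:
        print("my_pass ran")
    return module  # returns the module unchanged


mod = fx.symbolic_trace(nn.Linear(3, 3))
out = my_pass(mod, (torch.randn(2, 3),), verbose=True)  # kwarg now accepted
```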
diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py
index b3ad1bcd12..26e4332fdc 100644
--- a/py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py
+++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py
@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec


 # NOTE torch.prod will only accept one dim unlike other reduce ops which accept tuples
@@ -93,6 +93,26 @@ def forward(self, x):
             test_implicit_batch_dim=False,
         )

+    def test_prod_all_dims_with_dynamic_shape(
+        self,
+        op=torch.prod,
+    ):
+        class Prod(torch.nn.Module):
+            def forward(self, x):
+                return op(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            Prod(), input_specs, expected_ops={acc_ops.prod}
+        )
+

 if __name__ == "__main__":
     run_tests()
diff --git a/py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py b/py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py
index cec49cb400..db848eaf1c 100644
--- a/py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py
+++ b/py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py
@@ -4,7 +4,7 @@

 import torch
 from torch.testing._internal.common_utils import run_tests, TestCase
-from torch_tensorrt.fx import InputTensorSpec
+from torch_tensorrt.fx import generate_input_specs, InputTensorSpec, LowerSetting


 class TestTRTModule(TestCase):
@@ -47,6 +47,47 @@ def test_from_tensors_with_dynamic_batch_size(self):
             self.assertEqual(batch_size, shape[0])
             self.assertSequenceEqual(tensor.shape[1:], shape[1:])

+    def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self):
+        tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)]
+        batch_size_range = [2, 3, 4]
+        specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            tensors, batch_size_range, batch_dims=[0, 1]
+        )
+        for i, spec_and_tensor in enumerate(zip(specs, tensors)):
+            spec, tensor = spec_and_tensor
+            self._validate_spec(spec, tensor, dynamic_dims=[i])
+
+            for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]):
+                self.assertEqual(batch_size, shape[i])
+                tensor_shape = list(tensor.shape)
+                tensor_shape[i] = batch_size
+                self.assertSequenceEqual(tensor_shape, shape)
+
+    def test_generate_input_specs(self):
+        lower_setting = LowerSetting(
+            explicit_batch_dimension=False, max_batch_size=256, opt_profile_replica=2
+        )
+
+        # Implicit batch dim.
+        inputs = [torch.randn(1, 2, 3)]
+        specs = generate_input_specs(inputs, lower_setting)
+        for spec, tensor in zip(specs, inputs):
+            self._validate_spec(spec, tensor)
+
+        # Explicit batch dim without additional inputs.
+        lower_setting.explicit_batch_dimension = True
+        specs = generate_input_specs(inputs, lower_setting)
+        for spec, tensor in zip(specs, inputs):
+            self._validate_spec(spec, tensor, dynamic_dims=[0])
+            self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica)
+
+        # Explicit batch dim with additional inputs.
+        additional_inputs = [torch.randn(1, 1, 3)]
+        specs = generate_input_specs(inputs, lower_setting, additional_inputs)
+        for spec, tensor in zip(specs, inputs):
+            self._validate_spec(spec, tensor, dynamic_dims=[1])
+            self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica)
+

 if __name__ == "__main__":
     run_tests()
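Taken together, the new tests exercise the full flow. A condensed, runnable version of the final test case above, with the expected spec noted inline:

```python
import torch
from torch_tensorrt.fx import generate_input_specs, LowerSetting

lower_setting = LowerSetting(
    explicit_batch_dimension=True, max_batch_size=256, opt_profile_replica=2
)
inputs = [torch.randn(1, 2, 3)]
additional_inputs = [torch.randn(1, 1, 3)]  # differs from inputs only at dim 1

specs = generate_input_specs(inputs, lower_setting, additional_inputs)
print(specs[0].shape)              # (1, -1, 3): dim 1 inferred as the batch dim
print(len(specs[0].shape_ranges))  # 2: one range per optimization profile replica
```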