diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index dac526c7e0..19f273ba3f 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -11,6 +11,7 @@ from torch_tensorrt.fx.types import TRTNetwork, TRTTensor from .converter_registry import dynamo_tensorrt_converter +from .converter_utils import dynamic_unsupported_with_args _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -354,6 +355,34 @@ def aten_ops_softmax( ) +@dynamo_tensorrt_converter( + torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1]) +) +@dynamo_tensorrt_converter( + torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1]) +) +@dynamo_tensorrt_converter( + torch.ops.aten.split_with_sizes.default, + capability_validator=dynamic_unsupported_with_args([1]), +) +def aten_ops_split( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.split.split( + network, + target, + SourceIR.ATEN, + name, + input=args[0], + split_size_or_sections=args[1], + dim=args_bounds_check(args, 2, 0), + ) + + @dynamo_tensorrt_converter(torch.ops.aten.where.self) # type: ignore[misc] def aten_ops_where( network: TRTNetwork, diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 99cf2fa85a..8e0c9e777a 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,11 +1,12 @@ import functools import logging import re -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np import tensorrt as trt import torch +from torch import SymBool, SymFloat, SymInt from torch.fx.node import Target from torch_tensorrt.fx.converters.converter_utils import ( Frameworks, @@ -60,34 +61,53 @@ def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool: def dynamic_unsupported(node: torch.fx.Node) -> bool: + """Validates that a node has no dynamic args, kwargs, or outputs""" + return _dynamic_unsupported(node=node) + + +def dynamic_unsupported_with_args( + arg_positions_to_check: Optional[List[int]] = None, +) -> Callable[[torch.fx.Node], bool]: + """Returns a validator that a node has no dynamic args at specific positions""" + return functools.partial( + _dynamic_unsupported, arg_positions_to_check=arg_positions_to_check + ) + + +def _dynamic_unsupported( + node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None +) -> bool: # Validate that none of the inputs to the node have Dynamic shapes assert isinstance( node, torch.fx.Node ), "Inputs to validator functions must be FX Nodes" + def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool: + """Checks if a node itself has Dynamic properties""" + return getattr( + subnode.meta["val"], "_has_symbolic_sizes_strides", False + ) or isinstance(subnode.meta["val"], (SymFloat, SymInt, SymBool)) + # Check node value itself - if ("val" in node.meta) and getattr( - node.meta["val"], "_has_symbolic_sizes_strides", False - ): + if arg_positions_to_check is None and _is_subnode_dynamic(node): return False # Check node arguments individually - if any( - ( - ("val" in arg.meta) - and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False) - ) - 
for arg in node.args - if isinstance(arg, torch.fx.Node) + if arg_positions_to_check is None and any( + _is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node) + ): + return False + # Check specific arg positions if the caller has specified positions to check + elif arg_positions_to_check is not None and any( + _is_subnode_dynamic(node.args[i]) + for i in arg_positions_to_check + if isinstance(node.args[i], torch.fx.Node) ): return False # Check node keyword arguments individually - if any( - ( - ("val" in kwarg.meta) - and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False) - ) + if arg_positions_to_check is None and any( + _is_subnode_dynamic(kwarg) for kwarg in node.kwargs.values() if isinstance(kwarg, torch.fx.Node) ): diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index db7c877e8f..e615599eb4 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -15,6 +15,7 @@ select, shape, slice, + split, squeeze, unary, unsqueeze, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/split.py b/py/torch_tensorrt/dynamo/conversion/impl/split.py new file mode 100644 index 0000000000..1785e454e5 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/split.py @@ -0,0 +1,81 @@ +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch_tensorrt as trt +from torch import Tensor +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape +from torch_tensorrt.fx.converters.converter_utils import ( + has_dynamic_shape, + set_layer_name, +) +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + + +def split( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + split_size_or_sections: Union[int, List[int]], + dim: int = 0, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"split received input {input} that is not part " "of the TensorRT region!" + ) + + dynamic_shape = has_dynamic_shape(input.shape) + if dynamic_shape > 0: + # Check whether slice target dim is dynamic shape dim + assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" 
+
+    split_sizes = []
+    if isinstance(split_size_or_sections, int):
+        split_sizes.append(split_size_or_sections)
+    else:
+        for split_size_or_section in split_size_or_sections:
+            split_sizes.append(split_size_or_section)
+
+    start = [0] * len(input.shape)
+    stride = [1] * len(start)
+    offset = 0
+    if len(split_sizes) == 1:
+        num_splits = (input.shape[dim] + split_sizes[0] - 1) // split_sizes[0]
+        split_sizes = [split_sizes[0]] * num_splits
+    else:
+        num_splits = len(split_sizes)
+        sum_split_sizes = sum(split_sizes)
+        if sum_split_sizes != input.shape[dim]:
+            raise RuntimeError(
+                "split sizes don't add up to the tensor's size in the given dimension"
+            )
+
+    if num_splits < 1:
+        raise RuntimeError(
+            f"Invalid split: {input.shape[dim]} with split_size={split_sizes}"
+        )
+
+    max_offset = input.shape[dim]
+    # Add one slice layer per output chunk
+    output = []
+    for i in range(num_splits):
+        shape = list(input.shape)
+        shape[dim] = min(split_sizes[i], max_offset - offset)
+        start[dim] = offset
+        if dynamic_shape:
+            shape = get_shape_with_dynamic_shape(
+                network, target, source_ir, f"{name}_shape_{i}", shape, input
+            )
+        layer = network.add_slice(
+            input, start=start, shape=[] if dynamic_shape else shape, stride=stride
+        )
+        if dynamic_shape:
+            layer.set_input(2, shape)
+        offset += split_sizes[i]
+        set_layer_name(layer, target, f"{name}_{i}")
+        output.append(layer.get_output(0))
+    return output
diff --git a/tests/py/dynamo/conversion/test_split_aten.py b/tests/py/dynamo/conversion/test_split_aten.py
new file mode 100644
index 0000000000..ffd8e145b9
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_split_aten.py
@@ -0,0 +1,175 @@
+import torch
+from .harness import DispatchTestCase
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException
+
+
+# FIXME: check about implicit and explicit batch
+class TestSplitConverterNoDim(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("split_size_or_sections_no_dim", 2),
+        ]
+    )
+    def test_split(self, _, split_size_or_tensor):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.split(input, split_size_or_tensor)
+                return out
+
+        input = [torch.randn(10).reshape(5, 2)]
+        self.run_test(
+            TestModule(),
+            input,
+            expected_ops={torch.ops.aten.split.Tensor},
+            disable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            ("split_size_or_sections_list_no_dim_list", [1, 4]),
+        ]
+    )
+    def test_split_list(self, _, split_size_or_tensor):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.split(input, split_size_or_tensor)
+                return out
+
+        input = [torch.randn(10).reshape(5, 2)]
+        self.run_test(
+            TestModule(),
+            input,
+            expected_ops={torch.ops.aten.split_with_sizes.default},
+            disable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            ("split_size_or_sections_dims", 2, 1),
+        ]
+    )
+    def test_split_dim(self, _, split_size_or_tensor, dim):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.split(input, split_size_or_tensor, dim)
+                return out
+
+        input = [torch.randn(10).reshape(5, 2)]
+        self.run_test(
+            TestModule(),
+            input,
+            expected_ops={torch.ops.aten.split.Tensor},
+            disable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            ("split_size_or_sections_list_dims", [1, 1], 1),
+        ]
+    )
+    def test_split_dim_list(self, _, split_size_or_tensor, dim):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.split(input, split_size_or_tensor, dim)
+                return out
+
+        input = [torch.randn(10).reshape(5, 2)]
+        self.run_test(
+            TestModule(),
+            input,
+            expected_ops={torch.ops.aten.split_with_sizes.default},
+            disable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            ("split_size_or_sections_list_dims_not_full_list", [1, 1], 1),
+        ]
+    )
+    def test_split_dim_list_mismatched_sizes(self, _, split_size_or_tensor, dim):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.split(input, split_size_or_tensor, dim)
+                return out
+
+        input = [torch.randn(15).reshape(5, 3)]
+        with self.assertRaises(RuntimeError):
+            self.run_test(
+                TestModule(),
+                input,
+                expected_ops={torch.ops.aten.split_with_sizes.default},
+                disable_passes=True,
+            )
+
+    @parameterized.expand(
+        [
+            ("select_split_size_or_sections_dim_dynamic_shape", 2, 1),
+        ]
+    )
+    def test_split_dynamic(self, _, split_size_or_tensor, dim):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.split(input, split_size_or_tensor, dim)
+                return out
+
+        input_specs = [
+            Input(
+                shape=(1, 10, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(),
+            input_specs,
+            expected_ops={torch.ops.aten.split.Tensor},
+            disable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            ("select_chunk_dim", 6, 0),
+        ]
+    )
+    def test_chunk_unsupported(self, _, chunk, dim):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.ops.aten.chunk(input, chunk, dim)
+                return out
+
+        input = [torch.randn(11)]
+        with self.assertRaises(UnsupportedOperatorException):
+            self.run_test(
+                TestModule(),
+                input,
+                expected_ops={torch.ops.aten.split.Tensor},
+                disable_passes=True,
+            )
+
+
+if __name__ == "__main__":
+    run_tests()
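
For reference, a minimal, self-contained sketch (not part of the patch) of the validator pattern introduced in converter_utils.py. `dynamic_unsupported_with_args([1])` binds the argument positions to check via `functools.partial`, so the generated validator only rejects a node whose args at those positions carry symbolic values. For `aten.split.Tensor`, position 1 is `split_size_or_sections`, so dynamic split sizes fall back to PyTorch while a dynamic input tensor elsewhere is still allowed. The `FakeNode`/`FakeValue` classes below are hypothetical stand-ins for FX nodes with a populated `meta["val"]`, not the real torch.fx machinery.

# Illustration only: a simplified mock of the position-restricted dynamic-shape check.
import functools
from typing import List, Optional


class FakeValue:
    def __init__(self, symbolic: bool) -> None:
        self._has_symbolic_sizes_strides = symbolic


class FakeNode:
    def __init__(self, args=(), symbolic: bool = False) -> None:
        self.args = args
        self.kwargs = {}
        self.meta = {"val": FakeValue(symbolic)}


def _is_dynamic(node: FakeNode) -> bool:
    # Mirrors _is_subnode_dynamic for values with symbolic sizes/strides
    return getattr(node.meta["val"], "_has_symbolic_sizes_strides", False)


def _dynamic_unsupported(
    node: FakeNode, arg_positions_to_check: Optional[List[int]] = None
) -> bool:
    # True => converter may handle the node; False => fall back to PyTorch
    positions = (
        range(len(node.args)) if arg_positions_to_check is None else arg_positions_to_check
    )
    return not any(
        _is_dynamic(node.args[i]) for i in positions if isinstance(node.args[i], FakeNode)
    )


# Same shape as dynamic_unsupported_with_args([1]): only arg position 1 is checked.
split_size_validator = functools.partial(_dynamic_unsupported, arg_positions_to_check=[1])

dynamic_input = FakeNode(symbolic=True)            # a dynamic *input* tensor is fine
symbolic_sizes = FakeNode(symbolic=True)           # dynamic *split sizes* are rejected
literal_split_node = FakeNode(args=(dynamic_input, 2, 0))
symbolic_split_node = FakeNode(args=(dynamic_input, symbolic_sizes, 0))

print(split_size_validator(literal_split_node))    # True
print(split_size_validator(symbolic_split_node))   # False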
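
Also for reference, a short standalone check of the slice arithmetic that impl/split.py lowers torch.split to: ceil-division when an int split size is given, a size-mismatch error for explicit section lists, and clamping of the final chunk. `split_slices` is a hypothetical helper written for this note, not an API in the patch; it only covers the static-shape case exercised by the tests and compares the computed chunk shapes against torch.split.

# Standalone sanity check of the (start, size) pairs the converter's slice layers use.
from typing import List, Tuple, Union

import torch


def split_slices(
    shape: List[int], split_size_or_sections: Union[int, List[int]], dim: int = 0
) -> List[Tuple[List[int], List[int]]]:
    """Return one (start, size) pair per output chunk, as the converter's slice layers do."""
    if isinstance(split_size_or_sections, int):
        # Ceil-divide so a smaller trailing chunk is still emitted
        num_splits = (shape[dim] + split_size_or_sections - 1) // split_size_or_sections
        split_sizes = [split_size_or_sections] * num_splits
    else:
        split_sizes = list(split_size_or_sections)
        if sum(split_sizes) != shape[dim]:
            raise RuntimeError(
                "split sizes don't add up to the tensor's size in the given dimension"
            )

    slices = []
    offset = 0
    for size in split_sizes:
        start = [0] * len(shape)
        start[dim] = offset
        out_shape = list(shape)
        out_shape[dim] = min(size, shape[dim] - offset)  # clamp the last chunk
        slices.append((start, out_shape))
        offset += size
    return slices


x = torch.arange(10).reshape(5, 2)
expected = [tuple(t.shape) for t in torch.split(x, 2, dim=0)]
computed = [tuple(size) for _, size in split_slices(list(x.shape), 2, dim=0)]
assert computed == expected  # [(2, 2), (2, 2), (1, 2)]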