From 6cf1d676863c98ab08a24b3561855bbe60c8a9c9 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 15 Aug 2023 15:49:08 -0700 Subject: [PATCH 1/7] aten::split converter --- .../dynamo/conversion/aten_ops_converters.py | 16 +++ .../dynamo/conversion/impl/__init__.py | 1 + .../dynamo/conversion/impl/split.py | 82 +++++++++++++ tests/py/dynamo/converters/test_split_aten.py | 113 ++++++++++++++++++ 4 files changed, 212 insertions(+) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/split.py create mode 100644 tests/py/dynamo/converters/test_split_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index dac526c7e0..3ab88c1b23 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -354,6 +354,22 @@ def aten_ops_softmax( ) +@dynamo_tensorrt_converter( + torch.ops.aten.split.default, capability_validator=dynamic_unsupported +) +@dynamo_tensorrt_converter( + torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported +) +def aten_ops_split( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.split(network, target, SourceIR.ATEN, name, args[0], args[1], args[2]) + + @dynamo_tensorrt_converter(torch.ops.aten.where.self) # type: ignore[misc] def aten_ops_where( network: TRTNetwork, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index db7c877e8f..e615599eb4 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -15,6 +15,7 @@ select, shape, slice, + split, squeeze, unary, unsqueeze, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/split.py b/py/torch_tensorrt/dynamo/conversion/impl/split.py new file mode 100644 index 0000000000..775270e241 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/split.py @@ -0,0 +1,82 @@ +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast + +import numpy as np +import torch +import torch_tensorrt as trt +from torch import Tensor +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape +from torch_tensorrt.fx.converters.converter_utils import ( + has_dynamic_shape, + set_layer_name, +) +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + + +def split( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + split_size_or_sections: Union[int, List(int)], + dim: Optional[Any] = 0, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"split received input {input} that is not part " "of the TensorRT region!" + ) + + dim = cast(int, dim) + dynamic_shape = has_dynamic_shape(input.shape) + if network.has_implicit_batch_dimension: + assert dim != 0, "Can't split on batch dim when it's implicit!" + dim -= 1 + else: + if dynamic_shape > 0: + # Check whether slice target dim is dynamic shape dim + assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" 
+ + split_sizes = [] + if type(split_size_or_sections) == int: + split_sizes.append(cast(int, split_size_or_sections)) + else: + for split_size_or_section in split_size_or_sections: + split_sizes.append(cast(int, split_size_or_section)) + + start = [0] * len(input.shape) + stride = [1] * len(start) + offset = 0 + + if len(split_sizes) == 1: + num_splits = input.shape[dim] + split_sizes[0] - 1 // split_sizes[0] + split_sizes = [split_sizes[0]] * num_splits + else: + num_splits = len(split_sizes) + + if num_splits < 1: + raise RuntimeError( + f"Invalid split: {input.shape[dim]} with split_size={split_sizes}" + ) + + max_offset = input.shape[dim] + # add slice layers + output = [] + for i in range(num_splits): + shape = list(input.shape) + shape[dim] = min(split_sizes[i], cast(int, max_offset - offset)) + start[dim] = offset + if dynamic_shape: + shape = get_shape_with_dynamic_shape( + network, shape, input, target, f"{name}_shape_{i}" + ) + layer = network.add_slice( + input, start=start, shape=[] if dynamic_shape else shape, stride=stride + ) + if dynamic_shape: + layer.set_input(2, shape) + offset += split_sizes[i] + set_layer_name(layer, target, f"{name}_{i}") + output.append(layer.get_output(0)) + return output diff --git a/tests/py/dynamo/converters/test_split_aten.py b/tests/py/dynamo/converters/test_split_aten.py new file mode 100644 index 0000000000..419a800a43 --- /dev/null +++ b/tests/py/dynamo/converters/test_split_aten.py @@ -0,0 +1,113 @@ +import torch +from harness import DispatchTestCase +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + + +# FIXME: check about implicit and explicit batch +class TestSplitConverterNoDim(DispatchTestCase): + @parameterized.expand( + [ + ("split_size_or_sections_no_dim", 2), + ("split_size_or_sections_list_no_dim", [1, 4]), + ("split_size_or_sections_list_no_dim_not_full_split", [1, 3]), + ] + ) + def test_split(self, _, split_size_or_tensor): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.split(input, split_size_or_tensor) + return out + + input = torch.arange(10).reshape(5, 2) + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.split.default}, + ) + + +class TestSplitConverterWithDim(DispatchTestCase): + @parameterized.expand( + [ + ("split_size_or_sections_dim", 2, 1), + ("split_size_or_sections_list_dim", [1, 4], 1), + ("split_size_or_sections_list_dim_not_full_split", [1, 3], 1), + ] + ) + def test_split(self, _, split_size_or_tensor, dim): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.split(split_size_or_tensor, dim) + return out + + input = torch.arange(10).reshape(2, 5) + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.split.default}, + ) + + +class TestSplitConverterDynamicShape(DispatchTestCase): + @parameterized.expand( + [ + ("select_split_size_or_sections_dim", 2, 1), + ("select_split_size_or_sections_list_dim", [1, 4], 1), + ] + ) + def test_split(self, _, split_size_or_tensor, dim): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.split(input, split_size_or_tensor, dim) + return out + + input_specs = [ + Input( + shape=(1, 10, -1), + dtype=torch.float32, + shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + 
input_specs, + expected_ops={torch.ops.aten.split.default}, + ) + + +class TestSplitSymIntConverterImplicitBatch(DispatchTestCase): + @parameterized.expand( + [ + ("select_chunk_dim", 6, 0), + ] + ) + def test_chunk(self, _, chunk, dim): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.chunk(input, chunk, dim) + return out + + input = [torch.randn(11)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.split.default}, + ) + + +if __name__ == "__main__": + run_tests() From 3f4a2e8045dde05522f4535f62b93761ec19a5fa Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 18 Aug 2023 10:06:45 -0700 Subject: [PATCH 2/7] Addressing review comments and checking in tests --- .../dynamo/conversion/aten_ops_converters.py | 31 +++- .../dynamo/conversion/impl/split.py | 26 +-- tests/py/dynamo/converters/test_split_aten.py | 155 +++++++++++++++--- 3 files changed, 171 insertions(+), 41 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 3ab88c1b23..0c7a5a6f91 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -352,13 +352,28 @@ def aten_ops_softmax( return impl.normalization.softmax( network, target, SourceIR.ATEN, name, args[0], args[1] ) + +def dynamic_unsupported_split(node: torch.fx.Node) -> bool: + # Validate that none of the inputs to the node have Dynamic shapes + assert isinstance( + node, torch.fx.Node + ), "Inputs to validator functions must be FX Nodes" + + if isinstance(node.args[1], torch.fx.Node): + if getattr(node.args[1].meta["val"], "_has_symbolic_sizes_strides", True): + return False + return True @dynamo_tensorrt_converter( - torch.ops.aten.split.default, capability_validator=dynamic_unsupported + torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_split +) +@dynamo_tensorrt_converter( + torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_split ) @dynamo_tensorrt_converter( - torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported + torch.ops.aten.split_with_sizes.default, + capability_validator=dynamic_unsupported_split, ) def aten_ops_split( network: TRTNetwork, @@ -366,8 +381,16 @@ def aten_ops_split( args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.split(network, target, SourceIR.ATEN, name, args[0], args[1], args[2]) +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.split.split( + network, + target, + SourceIR.ATEN, + name, + input=args[0], + split_size_or_sections=args[1], + dim=args_bounds_check(args, 2, 0), + ) @dynamo_tensorrt_converter(torch.ops.aten.where.self) # type: ignore[misc] diff --git a/py/torch_tensorrt/dynamo/conversion/impl/split.py b/py/torch_tensorrt/dynamo/conversion/impl/split.py index 775270e241..658260228e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/split.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/split.py @@ -20,8 +20,8 @@ def split( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - split_size_or_sections: Union[int, List(int)], - dim: Optional[Any] = 0, + split_size_or_sections: Union[int, List[int]], + dim: int = 0, ) -> Union[TRTTensor, Sequence[TRTTensor]]: if not isinstance(input, TRTTensor): raise RuntimeError( @@ -30,16 +30,12 @@ def split( dim = cast(int, dim) dynamic_shape = 
has_dynamic_shape(input.shape) - if network.has_implicit_batch_dimension: - assert dim != 0, "Can't split on batch dim when it's implicit!" - dim -= 1 - else: - if dynamic_shape > 0: - # Check whether slice target dim is dynamic shape dim - assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" + if dynamic_shape > 0: + # Check whether slice target dim is dynamic shape dim + assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" split_sizes = [] - if type(split_size_or_sections) == int: + if isinstance(split_size_or_sections, int): split_sizes.append(cast(int, split_size_or_sections)) else: for split_size_or_section in split_size_or_sections: @@ -48,12 +44,16 @@ def split( start = [0] * len(input.shape) stride = [1] * len(start) offset = 0 - if len(split_sizes) == 1: - num_splits = input.shape[dim] + split_sizes[0] - 1 // split_sizes[0] + num_splits = (input.shape[dim] + split_sizes[0] - 1) // split_sizes[0] split_sizes = [split_sizes[0]] * num_splits else: num_splits = len(split_sizes) + sum_split_sizes = sum(split_sizes) + if sum_split_sizes != input.shape[dim]: + raise RuntimeError( + f"split sizes don't add up to the tensor's size in the given dimension" + ) if num_splits < 1: raise RuntimeError( @@ -69,7 +69,7 @@ def split( start[dim] = offset if dynamic_shape: shape = get_shape_with_dynamic_shape( - network, shape, input, target, f"{name}_shape_{i}" + network, target, source_ir, f"{name}_shape_{i}", shape, input ) layer = network.add_slice( input, start=start, shape=[] if dynamic_shape else shape, stride=stride diff --git a/tests/py/dynamo/converters/test_split_aten.py b/tests/py/dynamo/converters/test_split_aten.py index 419a800a43..44b6b0a4f0 100644 --- a/tests/py/dynamo/converters/test_split_aten.py +++ b/tests/py/dynamo/converters/test_split_aten.py @@ -3,6 +3,7 @@ from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input +from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException # FIXME: check about implicit and explicit batch @@ -10,8 +11,6 @@ class TestSplitConverterNoDim(DispatchTestCase): @parameterized.expand( [ ("split_size_or_sections_no_dim", 2), - ("split_size_or_sections_list_no_dim", [1, 4]), - ("split_size_or_sections_list_no_dim_not_full_split", [1, 3]), ] ) def test_split(self, _, split_size_or_tensor): @@ -23,20 +22,62 @@ def forward(self, input): out = torch.split(input, split_size_or_tensor) return out - input = torch.arange(10).reshape(5, 2) + input = [torch.randn(10).reshape(5, 2)] self.run_test( TestModule(), input, - expected_ops={torch.ops.aten.split.default}, + expected_ops={torch.ops.aten.split.Tensor}, + disable_passes=True, ) + @parameterized.expand( + [ + ("split_size_or_sections_list_no_dim_list", [1, 4]), + ] + ) + def test_split_list(self, _, split_size_or_tensor): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.split(input, split_size_or_tensor) + return out + + input = [torch.randn(10).reshape(5, 2)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.split_with_sizes.default}, + disable_passes=True, + ) -class TestSplitConverterWithDim(DispatchTestCase): @parameterized.expand( [ - ("split_size_or_sections_dim", 2, 1), - ("split_size_or_sections_list_dim", [1, 4], 1), - ("split_size_or_sections_list_dim_not_full_split", [1, 3], 1), + ("split_size_or_sections_list_no_dim_not_full_list", [1, 3]), + ] + ) + def 
test_split_not_full_list(self, _, split_size_or_tensor): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.split(input, split_size_or_tensor) + return out + + input = [torch.randn(10).reshape(5, 2)] + with self.assertRaises(RuntimeError): + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.split_with_sizes.default}, + disable_passes=True, + ) + + @parameterized.expand( + [ + ("split_size_or_sections_dims", 2, 1), ] ) def test_split(self, _, split_size_or_tensor, dim): @@ -45,25 +86,90 @@ def __init__(self): super().__init__() def forward(self, input): - out = torch.split(split_size_or_tensor, dim) + out = torch.split(input, split_size_or_tensor, dim) + return out + + input = [torch.randn(10).reshape(5, 2)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.split.Tensor}, + disable_passes=True, + ) + + @parameterized.expand( + [ + ("split_size_or_sections_list_dims", [1, 1], 1), + ] + ) + def test_split_dim(self, _, split_size_or_tensor, dim): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.split(input, split_size_or_tensor, dim) return out - input = torch.arange(10).reshape(2, 5) + input = [torch.randn(10).reshape(5, 2)] self.run_test( TestModule(), input, - expected_ops={torch.ops.aten.split.default}, + expected_ops={torch.ops.aten.split_with_sizes.default}, + disable_passes=True, ) + @parameterized.expand( + [ + ("split_size_or_sections_list_dims", [1, 1], 1), + ] + ) + def test_split_dim_list(self, _, split_size_or_tensor, dim): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.split(input, split_size_or_tensor, dim) + return out + + input = [torch.randn(10).reshape(5, 2)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.split_with_sizes.default}, + disable_passes=True, + ) -class TestSplitConverterDynamicShape(DispatchTestCase): @parameterized.expand( [ - ("select_split_size_or_sections_dim", 2, 1), - ("select_split_size_or_sections_list_dim", [1, 4], 1), + ("split_size_or_sections_list_dims_not_full_list", [1, 1], 1), ] ) - def test_split(self, _, split_size_or_tensor, dim): + def test_split_dim_list(self, _, split_size_or_tensor, dim): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.split(input, split_size_or_tensor, dim) + return out + + input = [torch.randn(15).reshape(5, 3)] + with self.assertRaises(RuntimeError): + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.split_with_sizes.default}, + disable_passes=True, + ) + + @parameterized.expand( + [ + ("select_split_size_or_sections_dim_dynamic_shape", 2, 1), + ] + ) + def test_split_dynamic(self, _, split_size_or_tensor, dim): class TestModule(torch.nn.Module): def __init__(self): super().__init__() @@ -82,17 +188,16 @@ def forward(self, input): self.run_test_with_dynamic_shape( TestModule(), input_specs, - expected_ops={torch.ops.aten.split.default}, + expected_ops={torch.ops.aten.split.Tensor}, + disable_passes=True, ) - -class TestSplitSymIntConverterImplicitBatch(DispatchTestCase): @parameterized.expand( [ ("select_chunk_dim", 6, 0), ] ) - def test_chunk(self, _, chunk, dim): + def test_split_dynamic(self, _, chunk, dim): class TestModule(torch.nn.Module): def __init__(self): super().__init__() @@ -102,11 +207,13 @@ def forward(self, input): 
return out input = [torch.randn(11)] - self.run_test( - TestModule(), - input, - expected_ops={torch.ops.aten.split.default}, - ) + with self.assertRaises(UnsupportedOperatorException): + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.split.Tensor}, + disable_passes=True, + ) if __name__ == "__main__": From e5b71204a41ffc29d96a6d50ab43cbc215418c75 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 18 Aug 2023 14:19:14 -0700 Subject: [PATCH 3/7] feat/fix: Update dynamic unsupported implementation - Add support for selecting individual argument positions to check and expand checking to include symbolic types, which are sometimes passed in as arguments --- .../dynamo/conversion/aten_ops_converters.py | 19 ++----- .../dynamo/conversion/converter_utils.py | 50 +++++++++++++------ 2 files changed, 38 insertions(+), 31 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 0c7a5a6f91..4f63ae13df 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -352,28 +352,17 @@ def aten_ops_softmax( return impl.normalization.softmax( network, target, SourceIR.ATEN, name, args[0], args[1] ) - -def dynamic_unsupported_split(node: torch.fx.Node) -> bool: - # Validate that none of the inputs to the node have Dynamic shapes - assert isinstance( - node, torch.fx.Node - ), "Inputs to validator functions must be FX Nodes" - - if isinstance(node.args[1], torch.fx.Node): - if getattr(node.args[1].meta["val"], "_has_symbolic_sizes_strides", True): - return False - return True @dynamo_tensorrt_converter( - torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_split + torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1]) ) @dynamo_tensorrt_converter( - torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_split + torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1]) ) @dynamo_tensorrt_converter( torch.ops.aten.split_with_sizes.default, - capability_validator=dynamic_unsupported_split, + capability_validator=dynamic_unsupported_with_args([1]), ) def aten_ops_split( network: TRTNetwork, @@ -381,7 +370,7 @@ def aten_ops_split( args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: +) -> Union[TRTTensor, Sequence[TRTTensor]]: return impl.split.split( network, target, diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 99cf2fa85a..c136668b77 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,11 +1,12 @@ import functools import logging import re -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union, Callable import numpy as np import tensorrt as trt import torch +from torch import SymBool, SymFloat, SymInt from torch.fx.node import Target from torch_tensorrt.fx.converters.converter_utils import ( Frameworks, @@ -60,34 +61,51 @@ def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool: def dynamic_unsupported(node: torch.fx.Node) -> bool: + """Validates that a node has no dynamic args, kwargs, or outputs""" + return _dynamic_unsupported(node=node) + + +def dynamic_unsupported_with_args( + 
arg_positions_to_check: Optional[List[int]] = None, +) -> Callable[[torch.fx.Node], bool]: + """Returns a validator that a node has no dynamic args at specific positions""" + return functools.partial(_dynamic_unsupported, arg_positions_to_check=arg_positions_to_check) + + +def _dynamic_unsupported( + node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None +) -> bool: # Validate that none of the inputs to the node have Dynamic shapes assert isinstance( node, torch.fx.Node ), "Inputs to validator functions must be FX Nodes" + def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool: + """Checks if a node itself has Dynamic properties""" + return getattr( + subnode.meta["val"], "_has_symbolic_sizes_strides", False + ) or isinstance(subnode.meta["val"], (SymFloat, SymInt, SymBool)) + # Check node value itself - if ("val" in node.meta) and getattr( - node.meta["val"], "_has_symbolic_sizes_strides", False - ): + if arg_positions_to_check is None and _is_subnode_dynamic(node): return False # Check node arguments individually - if any( - ( - ("val" in arg.meta) - and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False) - ) - for arg in node.args - if isinstance(arg, torch.fx.Node) + if arg_positions_to_check is None and any( + _is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node) + ): + return False + # Check specific arg positions if the caller has specified positions to check + elif arg_positions_to_check is not None and any( + _is_subnode_dynamic(node.args[i]) + for i in arg_positions_to_check + if isinstance(node.args[i], torch.fx.Node) ): return False # Check node keyword arguments individually - if any( - ( - ("val" in kwarg.meta) - and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False) - ) + if arg_positions_to_check is None and any( + _is_subnode_dynamic(kwarg) for kwarg in node.kwargs.values() if isinstance(kwarg, torch.fx.Node) ): From 71bd68f0bfcdd86ed0b15b0cab4453484df583c7 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 28 Aug 2023 16:45:58 -0700 Subject: [PATCH 4/7] combining split tests --- tests/py/dynamo/converters/test_split_aten.py | 46 +------------------ 1 file changed, 1 insertion(+), 45 deletions(-) diff --git a/tests/py/dynamo/converters/test_split_aten.py b/tests/py/dynamo/converters/test_split_aten.py index 44b6b0a4f0..af09b5db60 100644 --- a/tests/py/dynamo/converters/test_split_aten.py +++ b/tests/py/dynamo/converters/test_split_aten.py @@ -33,6 +33,7 @@ def forward(self, input): @parameterized.expand( [ ("split_size_or_sections_list_no_dim_list", [1, 4]), + ("split_size_or_sections_list_no_dim_not_full_list", [1, 3]), ] ) def test_split_list(self, _, split_size_or_tensor): @@ -52,29 +53,6 @@ def forward(self, input): disable_passes=True, ) - @parameterized.expand( - [ - ("split_size_or_sections_list_no_dim_not_full_list", [1, 3]), - ] - ) - def test_split_not_full_list(self, _, split_size_or_tensor): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input): - out = torch.split(input, split_size_or_tensor) - return out - - input = [torch.randn(10).reshape(5, 2)] - with self.assertRaises(RuntimeError): - self.run_test( - TestModule(), - input, - expected_ops={torch.ops.aten.split_with_sizes.default}, - disable_passes=True, - ) - @parameterized.expand( [ ("split_size_or_sections_dims", 2, 1), @@ -97,28 +75,6 @@ def forward(self, input): disable_passes=True, ) - @parameterized.expand( - [ - ("split_size_or_sections_list_dims", [1, 1], 1), - ] - ) - def 
test_split_dim(self, _, split_size_or_tensor, dim): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input): - out = torch.split(input, split_size_or_tensor, dim) - return out - - input = [torch.randn(10).reshape(5, 2)] - self.run_test( - TestModule(), - input, - expected_ops={torch.ops.aten.split_with_sizes.default}, - disable_passes=True, - ) - @parameterized.expand( [ ("split_size_or_sections_list_dims", [1, 1], 1), From 4fed932ebfd1da9a094842de727f984adb882a2c Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 1 Sep 2023 15:57:05 -0700 Subject: [PATCH 5/7] Removing cast --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 1 + py/torch_tensorrt/dynamo/conversion/converter_utils.py | 6 ++++-- py/torch_tensorrt/dynamo/conversion/impl/split.py | 7 +++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 4f63ae13df..19f273ba3f 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -11,6 +11,7 @@ from torch_tensorrt.fx.types import TRTNetwork, TRTTensor from .converter_registry import dynamo_tensorrt_converter +from .converter_utils import dynamic_unsupported_with_args _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index c136668b77..8e0c9e777a 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,7 +1,7 @@ import functools import logging import re -from typing import Any, List, Optional, Tuple, Union, Callable +from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np import tensorrt as trt @@ -69,7 +69,9 @@ def dynamic_unsupported_with_args( arg_positions_to_check: Optional[List[int]] = None, ) -> Callable[[torch.fx.Node], bool]: """Returns a validator that a node has no dynamic args at specific positions""" - return functools.partial(_dynamic_unsupported, arg_positions_to_check=arg_positions_to_check) + return functools.partial( + _dynamic_unsupported, arg_positions_to_check=arg_positions_to_check + ) def _dynamic_unsupported( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/split.py b/py/torch_tensorrt/dynamo/conversion/impl/split.py index 658260228e..4f17f0860e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/split.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/split.py @@ -28,7 +28,6 @@ def split( f"split received input {input} that is not part " "of the TensorRT region!" 
) - dim = cast(int, dim) dynamic_shape = has_dynamic_shape(input.shape) if dynamic_shape > 0: # Check whether slice target dim is dynamic shape dim @@ -36,10 +35,10 @@ def split( split_sizes = [] if isinstance(split_size_or_sections, int): - split_sizes.append(cast(int, split_size_or_sections)) + split_sizes.append(split_size_or_sections) else: for split_size_or_section in split_size_or_sections: - split_sizes.append(cast(int, split_size_or_section)) + split_sizes.append(split_size_or_section) start = [0] * len(input.shape) stride = [1] * len(start) @@ -65,7 +64,7 @@ def split( output = [] for i in range(num_splits): shape = list(input.shape) - shape[dim] = min(split_sizes[i], cast(int, max_offset - offset)) + shape[dim] = min(split_sizes[i], max_offset - offset) start[dim] = offset if dynamic_shape: shape = get_shape_with_dynamic_shape( From 797d51044c085767504dfb8349215d4cb5d503ea Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 20 Sep 2023 17:59:45 -0700 Subject: [PATCH 6/7] Change in test_split location --- tests/py/dynamo/{converters => conversion}/test_split_aten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/py/dynamo/{converters => conversion}/test_split_aten.py (99%) diff --git a/tests/py/dynamo/converters/test_split_aten.py b/tests/py/dynamo/conversion/test_split_aten.py similarity index 99% rename from tests/py/dynamo/converters/test_split_aten.py rename to tests/py/dynamo/conversion/test_split_aten.py index af09b5db60..87a7761001 100644 --- a/tests/py/dynamo/converters/test_split_aten.py +++ b/tests/py/dynamo/conversion/test_split_aten.py @@ -1,5 +1,5 @@ import torch -from harness import DispatchTestCase +from .harness import DispatchTestCase from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input From e95e0e4fabfb7f894f0171927b9a5d7e36ad6b2e Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 20 Sep 2023 18:18:28 -0700 Subject: [PATCH 7/7] Removing incorrect test and removing cast from split impl --- py/torch_tensorrt/dynamo/conversion/impl/split.py | 2 +- tests/py/dynamo/conversion/test_split_aten.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/split.py b/py/torch_tensorrt/dynamo/conversion/impl/split.py index 4f17f0860e..1785e454e5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/split.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/split.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch diff --git a/tests/py/dynamo/conversion/test_split_aten.py b/tests/py/dynamo/conversion/test_split_aten.py index 87a7761001..ffd8e145b9 100644 --- a/tests/py/dynamo/conversion/test_split_aten.py +++ b/tests/py/dynamo/conversion/test_split_aten.py @@ -33,7 +33,6 @@ def forward(self, input): @parameterized.expand( [ ("split_size_or_sections_list_no_dim_list", [1, 4]), - ("split_size_or_sections_list_no_dim_not_full_list", [1, 3]), ] ) def test_split_list(self, _, split_size_or_tensor):
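
A minimal usage sketch for the converter added in this series, assuming the top-level `torch_tensorrt.compile(..., ir="dynamo")` entry point; the module name, input shapes, split sizes, and the `min_block_size` setting below are illustrative only and not taken from the patches.

import torch
import torch_tensorrt


class SplitModule(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # A list of sizes lowers to torch.ops.aten.split_with_sizes.default;
        # an int size would lower to torch.ops.aten.split.Tensor instead.
        a, b = torch.split(x, [2, 3], dim=1)
        return a + 1.0, b * 2.0


model = SplitModule().eval().cuda()
inputs = [torch.randn(4, 5).cuda()]

# Compile through the dynamo IR; the registered converter handles the split op
# as long as the split dimension is not dynamic (see the capability validator
# added above). min_block_size=1 keeps the small graph from staying in Torch.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
print([out.shape for out in trt_model(*inputs)])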