From 0ce925f3da7cc1c2bd3ae0fabc592da9b682336d Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 2 Jan 2024 00:43:47 -0800 Subject: [PATCH 1/2] feat: support aten.roll dynamo converter --- .../dynamo/conversion/aten_ops_converters.py | 24 +++++++ .../dynamo/conversion/impl/permutation.py | 62 ++++++++++++++++++- tests/py/dynamo/conversion/test_roll_aten.py | 42 +++++++++++++ 3 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 tests/py/dynamo/conversion/test_roll_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 66fc734e50..2274ee8915 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2683,3 +2683,27 @@ def aten_ops_scalar_tensor( return impl.unary.scalar_tensor( ctx, target, SourceIR.ATEN, name, args[0], dtype=kwargs.get("dtype") ) + + +@dynamo_tensorrt_converter(torch.ops.aten.roll.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_roll( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.permutation.roll( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args_bounds_check(args, 2, []), + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index bdd9b46314..a0032a1a64 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -1,11 +1,16 @@ from typing import Optional, Sequence +import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim +from torch_tensorrt.dynamo.conversion.converter_utils import ( + flatten_dims, + get_positive_dim, +) from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import TRTTensor +from torch_tensorrt.fx.types import Shape, TRTTensor def permute( @@ -27,3 +32,56 @@ def permute( layer.second_transpose = tuple(permutation) set_layer_name(layer, target, name, source_ir) return layer.get_output(0) + + +def roll( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + shifts: Shape, + dims: Shape, +) -> TRTTensor: + shape = input.shape + if dims != []: + rank = len(shape) + start = [0] * rank + stride = [1] * rank + for i in range(len(dims)): + d = dims[i] + s = shifts[i] + start[d] += get_positive_dim( + -s, shape[d] + ) # in case that dims has multiple same dim + + layer = ctx.net.add_slice( + input, + start=start, + shape=shape, + stride=stride, + ) + layer.mode = trt.SliceMode.WRAP + set_layer_name(layer, target, f"{name}_slice_wrap", source_ir) + return layer.get_output(0) + + else: + flatten_shape = flatten_dims(input, 0, -1) + output = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape", input, flatten_shape + ) + start = [get_positive_dim(-shifts[0], output.shape[0])] + stride = [1] + layer = ctx.net.add_slice( + output, + start=start, + shape=flatten_shape, + stride=stride, + ) + layer.mode = trt.SliceMode.WRAP + set_layer_name(layer, target, f"{name}_slice_wrap", source_ir) + output = layer.get_output(0) + output = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_back", output, shape + ) + return output diff --git a/tests/py/dynamo/conversion/test_roll_aten.py b/tests/py/dynamo/conversion/test_roll_aten.py new file mode 100644 index 0000000000..3bbd2c4ad9 --- /dev/null +++ b/tests/py/dynamo/conversion/test_roll_aten.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestRollConverter(DispatchTestCase): + @parameterized.expand( + [ + ((4,), [2], [0]), + ((4,), [3], [0]), + ((4,), [-3, 2], [0, 0]), + ((4,), [-2], []), + ((4, 2), [2, 1], [0, 1]), + ((3, 3), [2, 1], [1, 1]), + ((4, 2), [2, -1], [-2, -1]), + ((4, 2), [4], []), + ((3, 4, 2), [1, 0, 2], [2, 0, -2]), + ((3, 4, 2), [1, -0, 2], [1, 1, 1]), + ( + (3, 4, 2), + [ + 5, + ], + [], + ), + ] + ) + def test_roll_list(self, shape, shifts, dims): + class Roll(nn.Module): + def forward(self, x): + return torch.ops.aten.roll.default(x, shifts, dims) + + inputs = [torch.randn(shape)] + self.run_test(Roll(), inputs) + + +if __name__ == "__main__": + run_tests() From a10ac64dd1cf52a98a0c3a6649e3d7a675275b0e Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 19 Jan 2024 03:26:57 -0800 Subject: [PATCH 2/2] add integer dims in test --- .../dynamo/conversion/impl/permutation.py | 13 +++++++++---- tests/py/dynamo/conversion/test_roll_aten.py | 3 ++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py index a0032a1a64..48a91faa40 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence +from typing import Optional, Sequence, Union import tensorrt as trt from torch.fx.node import Target @@ -10,7 +10,7 @@ get_positive_dim, ) from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import Shape, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def permute( @@ -40,10 +40,15 @@ def roll( source_ir: Optional[SourceIR], name: str, input: TRTTensor, - shifts: Shape, - dims: Shape, + shifts: Union[int, Sequence[int]], + dims: Union[int, Sequence[int]], ) -> TRTTensor: shape = input.shape + if isinstance(shifts, int): + shifts = [shifts] + if isinstance(dims, int): + dims = [dims] + if dims != []: rank = len(shape) start = [0] * rank diff --git a/tests/py/dynamo/conversion/test_roll_aten.py b/tests/py/dynamo/conversion/test_roll_aten.py index 3bbd2c4ad9..80e9020855 100644 --- a/tests/py/dynamo/conversion/test_roll_aten.py +++ b/tests/py/dynamo/conversion/test_roll_aten.py @@ -10,6 +10,7 @@ class TestRollConverter(DispatchTestCase): @parameterized.expand( [ + ((4,), (2,), 0), ((4,), [2], [0]), ((4,), [3], [0]), ((4,), [-3, 2], [0, 0]), @@ -29,7 +30,7 @@ class TestRollConverter(DispatchTestCase): ), ] ) - def test_roll_list(self, shape, shifts, dims): + def test_roll(self, shape, shifts, dims): class Roll(nn.Module): def forward(self, x): return torch.ops.aten.roll.default(x, shifts, dims)