@dynamo_tensorrt_converter(torch.ops.aten.roll.default)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_roll(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.roll.default`` by delegating to ``impl.permutation.roll``."""
    # The aten schema makes ``dims`` optional; an absent value means
    # "flatten, roll the 1-D view, restore the shape", signalled by [].
    roll_dims = args_bounds_check(args, 2, [])
    return impl.permutation.roll(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
        roll_dims,
    )
def roll(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    shifts: Union[int, Sequence[int]],
    dims: Union[int, Sequence[int]],
) -> TRTTensor:
    """Implement ``aten.roll`` as a TensorRT WRAP-mode slice.

    TensorRT has no native roll; a circular shift by ``s`` along a dimension
    is expressed as a slice of the full shape whose start is offset by ``-s``
    with ``SliceMode.WRAP``, so out-of-bounds reads wrap around the tensor.

    Args:
        ctx: Conversion context that owns the network being built.
        target: FX target being converted (used for layer naming).
        source_ir: Source IR the op came from, or ``None``.
        name: Base name for the TRT layers created here.
        input: TRT tensor to roll (static shape; ``input.shape`` is read here).
        shifts: Shift amount(s); may be negative or exceed the dim size.
        dims: Dimension(s) to roll; ``[]`` selects torch.roll's
            flatten-roll-restore behavior.

    Returns:
        The rolled TRT tensor.

    Raises:
        RuntimeError: If ``shifts`` and ``dims`` have different lengths,
            mirroring ``torch.roll``'s own validation (previously this
            surfaced as a bare IndexError).
    """
    shape = input.shape
    if isinstance(shifts, int):
        shifts = [shifts]
    if isinstance(dims, int):
        dims = [dims]

    if dims:
        if len(shifts) != len(dims):
            raise RuntimeError(
                f"roll: shifts length ({len(shifts)}) must match dims length ({len(dims)})"
            )
        rank = len(shape)
        start = [0] * rank
        stride = [1] * rank
        for d, s in zip(dims, shifts):
            # Accumulate so a dim listed twice rolls twice. The modulo keeps
            # the start inside [0, shape[d]) — without it a repeated dim
            # (e.g. shifts=[-3, 2], dims=[0, 0]) can push start past the dim
            # size, leaving correctness to how WRAP treats an out-of-range
            # start. Negative d indexes from the end, matching torch.
            start[d] = (start[d] + get_positive_dim(-s, shape[d])) % shape[d]

        layer = ctx.net.add_slice(
            input,
            start=start,
            shape=shape,
            stride=stride,
        )
        layer.mode = trt.SliceMode.WRAP
        set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
        return layer.get_output(0)

    else:
        # dims == []: torch.roll flattens the tensor, rolls the 1-D view by
        # shifts[0], then restores the original shape.
        flatten_shape = flatten_dims(input, 0, -1)
        output = impl.shuffle.reshape(
            ctx, target, source_ir, f"{name}_reshape", input, flatten_shape
        )
        start = [get_positive_dim(-shifts[0], output.shape[0])]
        stride = [1]
        layer = ctx.net.add_slice(
            output,
            start=start,
            shape=flatten_shape,
            stride=stride,
        )
        layer.mode = trt.SliceMode.WRAP
        set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
        output = layer.get_output(0)
        output = impl.shuffle.reshape(
            ctx, target, source_ir, f"{name}_reshape_back", output, shape
        )
        return output
class TestRollConverter(DispatchTestCase):
    # Each case is (input_shape, shifts, dims). dims == [] exercises
    # torch.roll's flatten-roll-restore path; repeated and negative dims
    # exercise the accumulated-start and end-relative indexing paths.
    @parameterized.expand(
        [
            ((4,), (2,), 0),
            ((4,), [2], [0]),
            ((4,), [3], [0]),
            ((4,), [-3, 2], [0, 0]),
            ((4,), [-2], []),
            ((4, 2), [2, 1], [0, 1]),
            ((3, 3), [2, 1], [1, 1]),
            ((4, 2), [2, -1], [-2, -1]),
            ((4, 2), [4], []),
            ((3, 4, 2), [1, 0, 2], [2, 0, -2]),
            ((3, 4, 2), [1, -0, 2], [1, 1, 1]),
            (
                (3, 4, 2),
                [
                    5,
                ],
                [],
            ),
        ]
    )
    def test_roll(self, shape, shifts, dims):
        """Compare TRT conversion of aten.roll against eager for one case."""

        class RollModule(nn.Module):
            def forward(self, x):
                return torch.ops.aten.roll.default(x, shifts, dims)

        self.run_test(RollModule(), [torch.randn(shape)])