feat: support aten.roll dynamo converter #2569

Merged (2 commits) on Feb 2, 2024
py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (24 additions, 0 deletions)
@@ -2683,3 +2683,27 @@ def aten_ops_scalar_tensor(
return impl.unary.scalar_tensor(
ctx, target, SourceIR.ATEN, name, args[0], dtype=kwargs.get("dtype")
)


@dynamo_tensorrt_converter(torch.ops.aten.roll.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_roll(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.permutation.roll(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args_bounds_check(args, 2, []),
)
py/torch_tensorrt/dynamo/conversion/impl/permutation.py (65 additions, 2 deletions)
@@ -1,9 +1,14 @@
from typing import Optional, Sequence
from typing import Optional, Sequence, Union

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.dynamo.conversion.converter_utils import (
flatten_dims,
get_positive_dim,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor

@@ -27,3 +32,61 @@ def permute(
layer.second_transpose = tuple(permutation)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def roll(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
shifts: Union[int, Sequence[int]],
dims: Union[int, Sequence[int]],
) -> TRTTensor:
shape = input.shape
if isinstance(shifts, int):
shifts = [shifts]
if isinstance(dims, int):
dims = [dims]

if dims != []:
rank = len(shape)
start = [0] * rank
stride = [1] * rank
for i in range(len(dims)):
d = dims[i]
s = shifts[i]
start[d] += get_positive_dim(
-s, shape[d]
)  # accumulate, since the same dim may appear in dims more than once

layer = ctx.net.add_slice(
input,
start=start,
shape=shape,
stride=stride,
)
layer.mode = trt.SliceMode.WRAP
set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
return layer.get_output(0)

else:
flatten_shape = flatten_dims(input, 0, -1)
output = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape", input, flatten_shape
)
start = [get_positive_dim(-shifts[0], output.shape[0])]
stride = [1]
layer = ctx.net.add_slice(
output,
start=start,
shape=flatten_shape,
stride=stride,
)
layer.mode = trt.SliceMode.WRAP
set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
output = layer.get_output(0)
output = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_back", output, shape
)
return output
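The converter's wrap-slice strategy can be illustrated in plain Python (a minimal sketch with illustrative names `roll_1d` and `roll_flat`, not part of the PR): the slice starts reading at `(-shift) mod size` and wraps around, matching `get_positive_dim(-s, shape[d])` above, and when `dims` is empty the input is flattened, rolled as one axis, and reshaped back.

```python
def roll_1d(xs, shift):
    """Roll a list by `shift`, mimicking a WRAP slice: the output
    starts reading the input at index (-shift) mod len(xs)."""
    n = len(xs)
    start = (-shift) % n  # same idea as get_positive_dim(-s, shape[d])
    return xs[start:] + xs[:start]

def roll_flat(nested, shift):
    """dims == [] case, shown for a 2-D nested list:
    flatten, roll the flat sequence, then reshape back."""
    rows, cols = len(nested), len(nested[0])
    flat = [v for row in nested for v in row]
    rolled = roll_1d(flat, shift)
    return [rolled[r * cols:(r + 1) * cols] for r in range(rows)]
```

For example, `roll_1d([1, 2, 3, 4], 1)` gives `[4, 1, 2, 3]`, and `roll_flat([[1, 2], [3, 4]], 1)` gives `[[4, 1], [2, 3]]`, matching `torch.roll` with and without `dims`.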
tests/py/dynamo/conversion/test_roll_aten.py (43 additions, 0 deletions)
@@ -0,0 +1,43 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestRollConverter(DispatchTestCase):
Collaborator: Consider adding a case where both shifts and dims are single integers, which is a supported case in the docstring. These may be cast to lists in the operator before the converter ever gets them, but it is still a valid input, I believe.

Author (@zewenli98, Jan 19, 2024): Thanks for your review! I found that the schema is:

- func: roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor

Does this mean shifts and dims should be a 1-d list?

Collaborator: roll(tensor, 2, 3) --> roll(tensor, [2], [3])
pool(3) --> pool([3, 3])
(shared as additional documentation)

Author (@zewenli98): @gs-olive Thanks for the details! Unfortunately, when testing shifts=2, dims=0, I got this error:

File "<eval_with_key>.0 from /home/zewenl/TensorRT/tests/py/dynamo/conversion/test_roll_aten.py:35 in forward", line 5, in forward
    roll_default = torch.ops.aten.roll.default(x, 2, 0);  x = None
  File "/home/zewenl/.local/lib/python3.10/site-packages/torch/_ops.py", line 571, in __call__
    return self_._op(*args, **(kwargs or {}))
RuntimeError: aten::roll() Expected a value of type 'List[int]' for argument 'shifts' but instead found type 'int'.
Position: 1
Value: 2
Declaration: aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor
 Python error details: TypeError: 'int' object is not iterable

Then I also tested shifts=(2,), dims=0, and it works. It seems that PyTorch requires shifts to be a list. According to the schema - func: roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor - I guess SymInt[1] and int[1] may have different behaviors?

Collaborator: It might be so, yes, though that is strange; thanks for the update.

Author (@zewenli98): Yes, I'm working on the adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor converter, and output_size expects List[int] as well:

RuntimeError: aten::adaptive_avg_pool2d() Expected a value of type 'List[int]' for argument 'output_size' but instead found type 'int'.
Position: 1
Value: 3
Declaration: aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
@parameterized.expand(
[
((4,), (2,), 0),
((4,), [2], [0]),
((4,), [3], [0]),
((4,), [-3, 2], [0, 0]),
((4,), [-2], []),
((4, 2), [2, 1], [0, 1]),
((3, 3), [2, 1], [1, 1]),
((4, 2), [2, -1], [-2, -1]),
((4, 2), [4], []),
((3, 4, 2), [1, 0, 2], [2, 0, -2]),
((3, 4, 2), [1, -0, 2], [1, 1, 1]),
(
(3, 4, 2),
[
5,
],
[],
),
]
)
def test_roll(self, shape, shifts, dims):
class Roll(nn.Module):
def forward(self, x):
return torch.ops.aten.roll.default(x, shifts, dims)

inputs = [torch.randn(shape)]
self.run_test(Roll(), inputs)


if __name__ == "__main__":
run_tests()
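One of the test cases above rolls the same dim twice (shifts=[-3, 2], dims=[0, 0]); the converter handles this by accumulating the slice starts (start[d] += ...), which is equivalent to a single roll by the summed shift. A quick pure-Python check of that equivalence (the helper `roll_1d` is illustrative, not part of the PR):

```python
def roll_1d(xs, shift):
    """Roll a list by `shift` using a wrapping start index."""
    n = len(xs)
    start = (-shift) % n
    return xs[start:] + xs[:start]

# Rolling by -3 and then by 2 along the same axis...
twice = roll_1d(roll_1d([1, 2, 3, 4], -3), 2)
# ...matches a single roll by the summed shift (-3 + 2 = -1).
once = roll_1d([1, 2, 3, 4], -1)
```

Both paths produce the same result, which is why accumulating into `start[d]` is sound for repeated dims.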