diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index f7ebf9e9c5..1eedf4e406 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -691,6 +691,36 @@ def aten_ops_chunk(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default)  # type: ignore[misc]
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)  # type: ignore[misc]
+def aten_ops_cumsum(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    """Convert ``aten.cumsum.default`` by delegating to ``impl.slice.cumsum``.
+
+    ``args[0]`` is the tensor to accumulate and ``args[1]`` is the dimension to
+    accumulate along; both are passed through unchanged.
+    """
+    # NOTE(review): ``kwargs`` is not forwarded, so a ``dtype=`` keyword given to
+    # the op would be ignored here -- confirm that is intended.
+    return impl.slice.cumsum(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.permute.default)  # type: ignore[misc]
 @enforce_tensor_types(
     {
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index e79d018636..19c5278137 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -1,10 +1,16 @@
 import math
 from typing import Optional
 
+import numpy as np
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
+    get_trt_tensor,
+)
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
@@ -157,3 +163,74 @@ def chunk(
         cnt += 1
 
     return result
+
+
+def cumsum(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int,
+) -> TRTTensor:
+    """Compute the cumulative sum of ``input`` along ``dim`` with a TensorRT loop.
+
+    Each iteration takes one slice of ``input`` along ``dim`` and adds it to a
+    running total that starts at zero; the per-iteration totals are
+    concatenated along ``dim`` to form the result.
+
+    Args:
+        ctx: Conversion context; layers are added to ``ctx.net``.
+        target: FX target, passed on when naming the created layers.
+        source_ir: Source IR tag, passed on when naming the created layers.
+        name: Prefix for the names of the created layers and constants.
+        input: Tensor to accumulate.
+        dim: Dimension to accumulate along; normalized with ``get_positive_dim``.
+
+    Returns:
+        Output tensor of the loop's concatenating output layer.
+
+    Raises:
+        ValueError: If any dimension of ``input`` is reported as negative.
+    """
+    input_shape = input.shape
+    dim = get_positive_dim(dim, len(input_shape))
+    # The trip count and the zero initial value are baked in as constants taken
+    # from ``input.shape``, so reject unknown extents before adding any layer.
+    if any(extent < 0 for extent in input_shape):
+        raise ValueError(
+            f"cumsum converter requires a static input shape, got {tuple(input_shape)}"
+        )
+
+    loop = ctx.net.add_loop()
+    axis = np.array(input_shape[dim])
+    trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")
+    loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)
+    iterator = loop.add_iterator(input, dim, reverse=False)
+    data = iterator.get_output(0)
+    new_dims = tuple(data.shape)
+    # NOTE(review): np.zeros defaults to float64 whatever the dtype of ``input``
+    # -- confirm get_trt_tensor / elementwise.add reconcile this for int inputs.
+    zeros = np.zeros(new_dims)
+    zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")
+
+    running_sum = loop.add_recurrence(zero_trttensor)
+    set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)
+    running_sum_tensor = running_sum.get_output(0)
+
+    current_sum = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_elementwise_add",
+        data,
+        running_sum_tensor,
+    )
+    # Feed this iteration's total back in as the recurrence's next value.
+    running_sum.set_input(1, current_sum)
+
+    loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim)
+    set_layer_name(loop_output, target, f"{name}_loop_output", source_ir)
+    # A concatenating loop output needs the number of slices as its input 1.
+    loop_output.set_input(1, trip_limit)
+    return loop_output.get_output(0)
diff --git a/tests/py/dynamo/conversion/test_cumsum_aten.py b/tests/py/dynamo/conversion/test_cumsum_aten.py
new file mode 100644
index 0000000000..e6aaea56cb
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_cumsum_aten.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestCumsumConverter(DispatchTestCase):
+    """Run ``aten.cumsum.default`` on 1-D, 2-D and 3-D inputs via ``run_test``.
+
+    Every case is ``(input shape, dim)``; negative ``dim`` values are included.
+    """
+
+    @parameterized.expand(
+        [
+            ((1,), 0),
+            ((2,), 0),
+            ((3,), -1),
+        ]
+    )
+    def test_cumsum_1D(self, shape, dim):
+        class Cumsum(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.cumsum.default(x, dim)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            Cumsum(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 1), 0),
+            ((3, 1), 1),
+            ((2, 3), -1),
+            ((2, 3), -2),
+        ]
+    )
+    def test_cumsum_2D(self, shape, dim):
+        class Cumsum(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.cumsum.default(x, dim)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            Cumsum(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((4, 2, 3), 0),
+            ((4, 2, 3), 1),
+            ((1, 2, 3), 2),
+            ((1, 2, 3), -1),
+            ((1, 2, 3), -2),
+        ]
+    )
+    def test_cumsum_3D(self, shape, dim):
+        class Cumsum(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.cumsum.default(x, dim)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            Cumsum(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()