diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index d0646e4bc6..d02466d4f5 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2541,3 +2541,25 @@ def aten_ops_sort(
         dim=args_bounds_check(args, 1, -1),
         descending=args_bounds_check(args, 2, False),
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.trunc.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_trunc(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.trunc(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
index 3a0fd47ac5..9ed5d0636d 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -5,7 +5,10 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
+    get_trt_tensor,
+)
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
 from torch_tensorrt.fx.types import TRTTensor
 
@@ -432,3 +435,27 @@ def erf(
     return convert_unary(
         ctx, target, source_ir, name, trt.UnaryOperation.ERF, input_val
     )
+
+
+def trunc(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    if input_val.dtype not in (trt.float16, trt.float32):
+        return impl.cast.to_copy(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_copy",
+            input_val,
+            input_val.dtype,
+            force_layer=True,
+        )
+
+    dividend = get_trt_tensor(ctx, 1, f"{name}_dividend")
+    return impl.elementwise.trunc_div(
+        ctx, target, source_ir, f"{name}_trunc", input_val, dividend
+    )
diff --git a/tests/py/dynamo/conversion/test_trunc_aten.py b/tests/py/dynamo/conversion/test_trunc_aten.py
new file mode 100644
index 0000000000..979ced17e2
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_trunc_aten.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestTruncConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((10,),),
+            ((1, 20),),
+            ((2, 3, 4),),
+            ((2, 3, 4, 5),),
+        ]
+    )
+    def test_trunc_float(self, shape):
+        class Trunc(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.trunc.default(input)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            Trunc(),
+            inputs,
+            enable_passes=True,
+        )
+
+    @parameterized.expand(
+        [
+            ((10,),),
+            ((1, 20),),
+            ((2, 3, 4),),
+            ((2, 3, 4, 5),),
+        ]
+    )
+    def test_trunc_int(self, shape):
+        class Trunc(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.trunc.default(input)
+
+        inputs = [torch.randint(-10, 10, shape, dtype=torch.int32)]
+        self.run_test(
+            Trunc(),
+            inputs,
+            enable_passes=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
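
The new `trunc` lowering relies on the identity `trunc(x) == trunc_div(x, 1)` for floating-point inputs, while integer inputs are already truncated and only receive an identity cast (with `force_layer=True`) so a TensorRT layer is still emitted. The snippet below is a plain-PyTorch sketch of that equivalence only; it is illustrative and does not exercise the actual TensorRT conversion path built via `impl.elementwise.trunc_div` and `impl.cast.to_copy`.

```python
# Illustrative reference for the lowering strategy in the diff above, written
# in plain PyTorch; it is not the TensorRT conversion itself.
import torch


def trunc_reference(x: torch.Tensor) -> torch.Tensor:
    if x.dtype not in (torch.float16, torch.float32):
        # Integer tensors are already truncated; the converter still inserts an
        # identity cast (force_layer=True) so the op maps to a real TRT layer.
        return x.clone()
    # Truncation toward zero equals truncated division by 1.
    return torch.div(x, 1, rounding_mode="trunc")


x = torch.randn(2, 3) * 10
assert torch.equal(trunc_reference(x), torch.trunc(x))

xi = torch.randint(-10, 10, (2, 3), dtype=torch.int32)
assert torch.equal(trunc_reference(xi), xi)
```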