diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 3878efc6af..451d218ee7 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -440,6 +440,37 @@ def aten_ops_expand(
     )
 
 
+def amax_param_validator(amax_node: Node) -> bool:
+    if len(amax_node.args) < 2:
+        _LOGGER.debug(
+            f"At least two args input and dim should be provided, but only got {len(amax_node.args)} args."
+        )
+        return False
+
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.amax.default, capability_validator=amax_param_validator
+)
+def aten_ops_amax(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.reduce.amax(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args_bounds_check(args, 2, replacement=False),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.exp.default)  # type: ignore[misc]
 def aten_ops_exp(
     network: TRTNetwork,
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index ed0f1bb843..e33bf09903 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,3 +1,4 @@
+import functools
 import logging
 import re
 from typing import List, Optional
@@ -7,6 +8,7 @@ from torch.fx.node import Target
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
+    get_axes_for_reduce_op,
     unified_dtype_converter,
 )
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
 
@@ -157,3 +159,8 @@ def broadcastable(
         if not (a_shape[i] == b_shape[i] or a_shape[i] == 1 or b_shape[i] == 1):
             return False
     return True
+
+
+get_axes_for_reduce_op = functools.partial(
+    get_axes_for_reduce_op, has_implicit_batch_dimension=False
+)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
index 8f7ab1badc..6bd315871c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -9,6 +9,7 @@
     matmul,
     normalization,
     permutation,
+    reduce,
     select,
     shape,
     slice,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py
new file mode 100644
index 0000000000..53070761dd
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py
@@ -0,0 +1,35 @@
+from typing import Optional, Tuple, Union
+
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
+    get_axes_for_reduce_op,
+)
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+
+def amax(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    dim: Union[int, Tuple[int]],
+    keepdim: bool = False,
+) -> TRTTensor:
+    if (isinstance(input_val, TRTTensor)) and (
+        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+    ):
+        input_val = cast_trt_tensor(network, input_val, trt.float32, name)
+
+    layer = network.add_reduce(
+        input_val,
+        trt.ReduceOperation.MAX,
+        axes=get_axes_for_reduce_op(dim),
+        keep_dims=keepdim,
+    )
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
diff --git a/tests/py/dynamo/converters/test_amax_aten.py b/tests/py/dynamo/converters/test_amax_aten.py
new file mode 100644
index 0000000000..b6024c83ba
--- /dev/null
+++ b/tests/py/dynamo/converters/test_amax_aten.py
@@ -0,0 +1,93 @@
+import torch
+import torch.nn as nn
+from harness import DispatchTestCase
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+
+class TestAmaxConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4), 1, True),
+            ((2, 3, 4, 5), 3, True),
+            ((2, 3, 4, 5), 2, False),
+            ((6, 7, 5, 4, 5), 4, False),
+        ]
+    )
+    def test_amax_dim_int_default(self, input_shape, dim, keep_dims):
+        class Amax(nn.Module):
+            def forward(self, x):
+                return torch.amax(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Amax(),
+            inputs,
+            expected_ops={torch.ops.aten.amax.default},
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), [1], True),
+            ((2, 1, 4, 5), [0, 3], True),
+            ((2, 3, 4, 5), [0, 1, 2, 3], False),
+            ((6, 7, 5, 4, 5), [1, 3, 4], False),
+        ]
+    )
+    def test_amax_dim_tuple_default(self, input_shape, dim, keep_dims):
+        class Amax(nn.Module):
+            def forward(self, x):
+                return torch.amax(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Amax(),
+            inputs,
+            expected_ops={torch.ops.aten.amax.default},
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), 1, True, torch.int, 0, 5),
+            ((2, 3, 4, 5), 3, True, torch.int, -10, 10),
+            ((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
+            ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
+        ]
+    )
+    def test_amax_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
+        class Amax(nn.Module):
+            def forward(self, x):
+                return torch.amax(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            Amax(),
+            inputs,
+            expected_ops={torch.ops.aten.amax.default},
+            check_dtype=False,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), [1], True, torch.int, 0, 5),
+            ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
+            ((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),
+            ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
+        ]
+    )
+    def test_amax_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high):
+        class Amax(nn.Module):
+            def forward(self, x):
+                return torch.amax(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            Amax(),
+            inputs,
+            expected_ops={torch.ops.aten.amax.default},
+            check_dtype=False,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
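
Note for reviewers: TensorRT's `IReduceLayer` takes its reduction dimensions as an `axes` bitmask rather than a dim index or list, which is why `converter_utils.py` re-exports `get_axes_for_reduce_op` partially applied with `has_implicit_batch_dimension=False`. Below is a minimal standalone sketch of that encoding; the function name `axes_bitmask` is illustrative only and is not the library implementation.

```python
from typing import Sequence, Union


def axes_bitmask(dim: Union[int, Sequence[int]]) -> int:
    """Hypothetical sketch: map a PyTorch `dim` argument to a TensorRT
    reduce-axes bitmask, assuming no implicit batch dimension
    (bit i of the mask selects dimension i)."""
    dims = (dim,) if isinstance(dim, int) else tuple(dim)
    mask = 0
    for d in dims:
        mask |= 1 << d  # set the bit for each reduced dimension
    return mask


if __name__ == "__main__":
    print(bin(axes_bitmask(2)))       # 0b100  -> reduce over dim 2 only
    print(bin(axes_bitmask([0, 3])))  # 0b1001 -> reduce over dims 0 and 3
```

The sketch omits negative-dim normalization; the helper in `torch_tensorrt.fx.converters.converter_utils` is the source of truth for edge-case handling. With `keep_dims=True`, the reduced dimensions are retained with size 1, matching `torch.amax(..., keepdim=True)`.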