From 5ec8a2b3afefee95ba0b27031014bfecb0212217 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Tue, 5 Dec 2023 16:18:54 -0800 Subject: [PATCH] feat: support argmin aten converter and refactor argmax --- .../dynamo/conversion/aten_ops_converters.py | 22 +++++++++- .../dynamo/conversion/impl/__init__.py | 2 +- .../conversion/impl/{argmax.py => topk.py} | 33 ++++++++++++++- .../py/dynamo/conversion/test_argmax_aten.py | 8 ++-- .../py/dynamo/conversion/test_argmin_aten.py | 41 +++++++++++++++++++ 5 files changed, 98 insertions(+), 8 deletions(-) rename py/torch_tensorrt/dynamo/conversion/impl/{argmax.py => topk.py} (76%) create mode 100644 tests/py/dynamo/conversion/test_argmin_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index ede6e5e6a9..8eb07c07a7 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2284,7 +2284,27 @@ def aten_ops_argmax( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.argmax.argmax( + return impl.topk.argmax( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + dim=args_bounds_check(args, 1), + keep_dim=args_bounds_check(args, 2, False), + ) + + +@enforce_tensor_types({0: (TRTTensor,)}) +@dynamo_tensorrt_converter(torch.ops.aten.argmin.default) +def aten_ops_argmin( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.topk.argmin( ctx, target, SourceIR.ATEN, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index 62bf556beb..5bace705cb 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -3,7 +3,6 @@ from . 
import ( activation, addmm, - argmax, attention, cast, cat, @@ -26,6 +25,7 @@ slice, split, squeeze, + topk, unary, unsqueeze, ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/argmax.py b/py/torch_tensorrt/dynamo/conversion/impl/topk.py similarity index 76% rename from py/torch_tensorrt/dynamo/conversion/impl/argmax.py rename to py/torch_tensorrt/dynamo/conversion/impl/topk.py index f45aec0be5..a9e11cc537 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/argmax.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/topk.py @@ -15,12 +15,13 @@ from torch_tensorrt.fx.types import TRTTensor -def argmax( +def argmax_argmin( ctx: ConversionContext, target: Target, source_ir: Optional[SourceIR], name: str, input: TRTTensor, + topk_option: trt.TopKOperation, dim: Optional[int], keep_dim: bool = False, ) -> TRTTensor: @@ -49,7 +50,7 @@ def argmax( get_positive_dim(dim if dim is not None else 0, len(out.shape)) ) - topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask) + topk_layer = ctx.net.add_topk(out, topk_option, 1, reduce_mask) set_layer_name(topk_layer, target, name, source_ir) out = topk_layer.get_output(1) @@ -72,3 +73,31 @@ def argmax( out = impl.squeeze.squeeze(ctx, target, source_ir, f"{name}_squeeze", out, dim) return out + + +def argmax( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: Optional[int], + keep_dim: bool = False, +) -> TRTTensor: + return argmax_argmin( + ctx, target, source_ir, name, input, trt.TopKOperation.MAX, dim, keep_dim + ) + + +def argmin( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: Optional[int], + keep_dim: bool = False, +) -> TRTTensor: + return argmax_argmin( + ctx, target, source_ir, name, input, trt.TopKOperation.MIN, dim, keep_dim + ) diff --git a/tests/py/dynamo/conversion/test_argmax_aten.py b/tests/py/dynamo/conversion/test_argmax_aten.py index bf469d0901..a3f9f67b95 100644 --- a/tests/py/dynamo/conversion/test_argmax_aten.py +++ b/tests/py/dynamo/conversion/test_argmax_aten.py @@ -11,11 +11,11 @@ class TestArgmaxConverter(DispatchTestCase): [ # input dimension == 1 ("dim_1_keep_dim_true", (3,), 0, True), - ("dim_1_keep_dim_true", (3,), 0, False), + ("dim_1_keep_dim_false", (3,), 0, False), # dim == None - ("dim_none", (3,), None, True), - ("dim_none", (3, 3), None, True), - ("dim_none", (3, 3, 3), None, False), + ("dim_1_none_true", (3,), None, True), + ("dim_2_none_true", (3, 3), None, True), + ("dim_3_none_false", (3, 3, 3), None, False), # # common cases ("dim_1_keep_dim_true", (3, 3), 1, True), ("dim_1_keep_dim_false", (3, 3), 1, False), diff --git a/tests/py/dynamo/conversion/test_argmin_aten.py b/tests/py/dynamo/conversion/test_argmin_aten.py new file mode 100644 index 0000000000..f06284f394 --- /dev/null +++ b/tests/py/dynamo/conversion/test_argmin_aten.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestArgminConverter(DispatchTestCase): + @parameterized.expand( + [ + # input dimension == 1 + ("dim_1_keep_dim_true", (3,), 0, True), + ("dim_1_keep_dim_false", (3,), 0, False), + # dim == None + ("dim_1_none_true", (3,), None, True), + ("dim_2_none_true", (3, 3), None, True), + ("dim_3_none_false", (3, 3, 3), None, False), + # # common cases + ("dim_1_keep_dim_true", (3, 3), 1, True), + ("dim_1_keep_dim_false", (3, 3), 1, 
False), + ("dim_0_keep_dim_true", (4, 4, 4), 0, True), + ("dim_0_keep_dim_false", (4, 4, 4), 0, False), + ("dim_negative_keep_dim_true", (1, 2, 3), -1, True), + ] + ) + def test_argmin(self, _, input_shape, dim, keep_dim): + class ArgMin(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.ops.aten.argmin.default(input, dim, keep_dim) + + input = [torch.randn(*input_shape)] + + self.run_test(ArgMin(), input) + + +if __name__ == "__main__": + run_tests()
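
Not part of the diff above: a minimal sketch of how the new argmin path could be
exercised end to end once this patch is applied. It assumes a working
torch_tensorrt install with a CUDA-capable TensorRT backend; the module name and
compile settings below are illustrative, not prescribed by this change.

    import torch
    import torch_tensorrt


    class ArgMin(torch.nn.Module):
        def forward(self, x):
            # Dispatched to aten_ops_argmin and lowered through impl.topk.argmin,
            # i.e. a TensorRT TopK layer with k=1 and TopKOperation.MIN.
            return torch.ops.aten.argmin.default(x, 1, True)


    model = ArgMin().eval().cuda()
    inputs = [torch.randn(4, 4, device="cuda")]

    # ir="dynamo" routes the graph through the converter registry extended by
    # this patch; min_block_size=1 forces conversion of the single-op graph
    # instead of falling back to eager PyTorch.
    trt_model = torch_tensorrt.compile(
        model, ir="dynamo", inputs=inputs, min_block_size=1
    )
    print(trt_model(*inputs))

The refactor itself keeps argmax and argmin as thin wrappers over a shared
argmax_argmin helper, so both ops reuse the same reshape/TopK/squeeze plumbing
and differ only in the TopKOperation passed to add_topk.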