Commit 1f9daf5

feat: support argmin aten converter and refactor argmax

1 parent 4f8eb56 · commit 1f9daf5

This commit adds a converter for torch.ops.aten.argmin.default, generalizes the existing argmax implementation into a shared argmax_argmin helper (renaming impl/argmax.py to impl/topk.py), and adds a test suite for the new converter while deduplicating the argmax test case names.

5 files changed: +98 -8 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+21 -1)

@@ -2267,7 +2267,27 @@ def aten_ops_argmax(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.argmax.argmax(
+    return impl.topk.argmax(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        dim=args_bounds_check(args, 1),
+        keep_dim=args_bounds_check(args, 2, False),
+    )
+
+
+@enforce_tensor_types({0: (TRTTensor,)})
+@dynamo_tensorrt_converter(torch.ops.aten.argmin.default)
+def aten_ops_argmin(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.topk.argmin(
         ctx,
         target,
         SourceIR.ATEN,
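For context, a minimal sketch of how the new converter gets exercised end to end. The module, shapes, and compile settings here are illustrative assumptions, not part of the commit:

import torch
import torch_tensorrt

class ArgMin(torch.nn.Module):
    def forward(self, x):
        # Traces to torch.ops.aten.argmin.default, which this commit
        # now routes through impl.topk.argmin
        return torch.argmin(x, dim=1, keepdim=True)

model = ArgMin().eval().cuda()
inputs = [torch.randn(4, 8).cuda()]
# ir="dynamo" selects the dynamo frontend whose converter registry this commit extends
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))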

py/torch_tensorrt/dynamo/conversion/impl/__init__.py (+1 -1)

@@ -3,7 +3,6 @@
 from . import (
     activation,
     addmm,
-    argmax,
     attention,
     cast,
     cat,
@@ -25,6 +24,7 @@
     slice,
     split,
     squeeze,
+    topk,
     unary,
     unsqueeze,
 )

py/torch_tensorrt/dynamo/conversion/impl/argmax.py renamed to py/torch_tensorrt/dynamo/conversion/impl/topk.py (+31 -2)

@@ -15,12 +15,13 @@
 from torch_tensorrt.fx.types import TRTTensor
 
 
-def argmax(
+def argmax_argmin(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
+    topk_option: trt.TopKOperation,
     dim: Optional[int],
     keep_dim: bool = False,
 ) -> TRTTensor:
@@ -49,7 +50,7 @@ def argmax(
         get_positive_dim(dim if dim is not None else 0, len(out.shape))
     )
 
-    topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
+    topk_layer = ctx.net.add_topk(out, topk_option, 1, reduce_mask)
     set_layer_name(topk_layer, target, name, source_ir)
 
     out = topk_layer.get_output(1)
@@ -72,3 +73,31 @@ def argmax(
     out = impl.squeeze.squeeze(ctx, target, source_ir, f"{name}_squeeze", out, dim)
 
     return out
+
+
+def argmax(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: Optional[int],
+    keep_dim: bool = False,
+) -> TRTTensor:
+    return argmax_argmin(
+        ctx, target, source_ir, name, input, trt.TopKOperation.MAX, dim, keep_dim
+    )
+
+
+def argmin(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: Optional[int],
+    keep_dim: bool = False,
+) -> TRTTensor:
+    return argmax_argmin(
+        ctx, target, source_ir, name, input, trt.TopKOperation.MIN, dim, keep_dim
+    )
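For intuition, this lowering works because a k=1 TopK's second output is the index of the extremum along the reduced dimension: MAX yields argmax and MIN yields argmin. A quick PyTorch analogue of that equivalence (illustrative only, not part of the commit):

import torch

x = torch.randn(2, 5)

# k=1 top-k indices along a dim reproduce argmax/argmin with keepdim=True;
# TensorRT's add_topk(..., k=1, ...) output index 1 plays the same role.
assert torch.equal(x.topk(1, dim=1).indices, x.argmax(dim=1, keepdim=True))
assert torch.equal(
    x.topk(1, dim=1, largest=False).indices, x.argmin(dim=1, keepdim=True)
)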

tests/py/dynamo/conversion/test_argmax_aten.py (+4 -4)

@@ -11,11 +11,11 @@ class TestArgmaxConverter(DispatchTestCase):
     [
         # input dimension == 1
         ("dim_1_keep_dim_true", (3,), 0, True),
-        ("dim_1_keep_dim_true", (3,), 0, False),
+        ("dim_1_keep_dim_false", (3,), 0, False),
         # dim == None
-        ("dim_none", (3,), None, True),
-        ("dim_none", (3, 3), None, True),
-        ("dim_none", (3, 3, 3), None, False),
+        ("dim_1_none_true", (3,), None, True),
+        ("dim_2_none_true", (3, 3), None, True),
+        ("dim_3_none_false", (3, 3, 3), None, False),
         # # common cases
         ("dim_1_keep_dim_true", (3, 3), 1, True),
         ("dim_1_keep_dim_false", (3, 3), 1, False),
tests/py/dynamo/conversion/test_argmin_aten.py (new file, +41)

@@ -0,0 +1,41 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestArgminConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            # input dimension == 1
+            ("dim_1_keep_dim_true", (3,), 0, True),
+            ("dim_1_keep_dim_false", (3,), 0, False),
+            # dim == None
+            ("dim_1_none_true", (3,), None, True),
+            ("dim_2_none_true", (3, 3), None, True),
+            ("dim_3_none_false", (3, 3, 3), None, False),
+            # # common cases
+            ("dim_1_keep_dim_true", (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (4, 4, 4), 0, True),
+            ("dim_0_keep_dim_false", (4, 4, 4), 0, False),
+            ("dim_negative_keep_dim_true", (1, 2, 3), -1, True),
+        ]
+    )
+    def test_argmin(self, _, input_shape, dim, keep_dim):
+        class ArgMin(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.argmin.default(input, dim, keep_dim)
+
+        input = [torch.randn(*input_shape)]
+
+        self.run_test(ArgMin(), input)
+
+
+if __name__ == "__main__":
+    run_tests()
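To exercise just this new suite locally, something like python -m pytest tests/py/dynamo/conversion/test_argmin_aten.py (assuming a development install with pytest and a CUDA-capable GPU) should work; the DispatchTestCase harness compiles the module through the dynamo conversion path and compares the TensorRT output against eager PyTorch.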
