Commit 5ec8a2b (1 parent: 7f7c907)

feat: support argmin aten converter and refactor argmax

5 files changed: +98 -8 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+21 -1)

```diff
@@ -2284,7 +2284,27 @@ def aten_ops_argmax(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.argmax.argmax(
+    return impl.topk.argmax(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        dim=args_bounds_check(args, 1),
+        keep_dim=args_bounds_check(args, 2, False),
+    )
+
+
+@enforce_tensor_types({0: (TRTTensor,)})
+@dynamo_tensorrt_converter(torch.ops.aten.argmin.default)
+def aten_ops_argmin(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.topk.argmin(
         ctx,
         target,
         SourceIR.ATEN,
```
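With `aten.argmin.default` registered above, a model that calls `torch.argmin` can be lowered end to end through the dynamo frontend. A minimal sketch of exercising the new converter (assuming a CUDA-capable `torch_tensorrt` build; the shapes and `compile` arguments are illustrative, not part of this commit):

```python
import torch
import torch_tensorrt


class ArgMinModel(torch.nn.Module):
    def forward(self, x):
        # torch.argmin dispatches to torch.ops.aten.argmin.default,
        # which aten_ops_argmin lowers to a TensorRT TopK(MIN) layer.
        return torch.argmin(x, dim=1, keepdim=True)


model = ArgMinModel().eval().cuda()
inputs = [torch.randn(4, 8).cuda()]

# ir="dynamo" routes compilation through the aten converter registry.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))
```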

py/torch_tensorrt/dynamo/conversion/impl/__init__.py (+1 -1)

```diff
@@ -3,7 +3,6 @@
 from . import (
     activation,
     addmm,
-    argmax,
     attention,
     cast,
     cat,
@@ -26,6 +25,7 @@
     slice,
     split,
     squeeze,
+    topk,
     unary,
     unsqueeze,
 )
```

py/torch_tensorrt/dynamo/conversion/impl/argmax.py renamed to py/torch_tensorrt/dynamo/conversion/impl/topk.py (+31 -2)

```diff
@@ -15,12 +15,13 @@
 from torch_tensorrt.fx.types import TRTTensor


-def argmax(
+def argmax_argmin(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
+    topk_option: trt.TopKOperation,
     dim: Optional[int],
     keep_dim: bool = False,
 ) -> TRTTensor:
@@ -49,7 +50,7 @@ def argmax(
         get_positive_dim(dim if dim is not None else 0, len(out.shape))
     )

-    topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
+    topk_layer = ctx.net.add_topk(out, topk_option, 1, reduce_mask)
     set_layer_name(topk_layer, target, name, source_ir)

     out = topk_layer.get_output(1)
@@ -72,3 +73,31 @@ def argmax(
     out = impl.squeeze.squeeze(ctx, target, source_ir, f"{name}_squeeze", out, dim)

     return out
+
+
+def argmax(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: Optional[int],
+    keep_dim: bool = False,
+) -> TRTTensor:
+    return argmax_argmin(
+        ctx, target, source_ir, name, input, trt.TopKOperation.MAX, dim, keep_dim
+    )
+
+
+def argmin(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: Optional[int],
+    keep_dim: bool = False,
+) -> TRTTensor:
+    return argmax_argmin(
+        ctx, target, source_ir, name, input, trt.TopKOperation.MIN, dim, keep_dim
+    )
```
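The shared helper works because TensorRT's `ITopKLayer` produces two outputs: index 0 holds the top-k values and index 1 holds their positions, so `get_output(1)` yields the arg-index regardless of whether the operation is MAX or MIN. A standalone sketch of that layer contract (assuming the `tensorrt` Python API; `add_arg_index` is an illustrative name, not a helper from this commit):

```python
import tensorrt as trt


def add_arg_index(
    network: trt.INetworkDefinition,
    inp: trt.ITensor,
    axis: int,
    op: trt.TopKOperation,
) -> trt.ITensor:
    # TensorRT encodes the reduction axes as a bitmask.
    axes_mask = 1 << axis
    # k=1 keeps only the single extreme element along the masked axis.
    topk = network.add_topk(inp, op, 1, axes_mask)
    # get_output(0) -> the extreme values; get_output(1) -> their indices.
    return topk.get_output(1)
```

Parameterizing on `trt.TopKOperation` is what lets `argmax` and `argmin` collapse into thin wrappers around one implementation.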

tests/py/dynamo/conversion/test_argmax_aten.py (+4 -4)

```diff
@@ -11,11 +11,11 @@ class TestArgmaxConverter(DispatchTestCase):
         [
             # input dimension == 1
             ("dim_1_keep_dim_true", (3,), 0, True),
-            ("dim_1_keep_dim_true", (3,), 0, False),
+            ("dim_1_keep_dim_false", (3,), 0, False),
             # dim == None
-            ("dim_none", (3,), None, True),
-            ("dim_none", (3, 3), None, True),
-            ("dim_none", (3, 3, 3), None, False),
+            ("dim_1_none_true", (3,), None, True),
+            ("dim_2_none_true", (3, 3), None, True),
+            ("dim_3_none_false", (3, 3, 3), None, False),
             # # common cases
             ("dim_1_keep_dim_true", (3, 3), 1, True),
             ("dim_1_keep_dim_false", (3, 3), 1, False),
```
tests/py/dynamo/conversion/test_argmin_aten.py (new file, +41)

```diff
@@ -0,0 +1,41 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestArgminConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            # input dimension == 1
+            ("dim_1_keep_dim_true", (3,), 0, True),
+            ("dim_1_keep_dim_false", (3,), 0, False),
+            # dim == None
+            ("dim_1_none_true", (3,), None, True),
+            ("dim_2_none_true", (3, 3), None, True),
+            ("dim_3_none_false", (3, 3, 3), None, False),
+            # # common cases
+            ("dim_1_keep_dim_true", (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (4, 4, 4), 0, True),
+            ("dim_0_keep_dim_false", (4, 4, 4), 0, False),
+            ("dim_negative_keep_dim_true", (1, 2, 3), -1, True),
+        ]
+    )
+    def test_argmin(self, _, input_shape, dim, keep_dim):
+        class ArgMin(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.argmin.default(input, dim, keep_dim)
+
+        input = [torch.randn(*input_shape)]
+
+        self.run_test(ArgMin(), input)
+
+
+if __name__ == "__main__":
+    run_tests()
```
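Because the new module ends in `run_tests()`, it can be run directly with `python tests/py/dynamo/conversion/test_argmin_aten.py` or through pytest, assuming a CUDA-capable `torch_tensorrt` environment like the rest of the dynamo conversion suite.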
