Skip to content

Commit 786581b

Browse files
committed
support argmax converter
1 parent 91fcea4 commit 786581b

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import Optional, cast
2+
3+
import numpy as np
4+
from torch.fx.node import Target
5+
from torch_tensorrt.dynamo._SourceIR import SourceIR
6+
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
7+
from torch_tensorrt.fx.converters.converter_utils import (
8+
get_positive_dim,
9+
has_dynamic_shape,
10+
to_numpy,
11+
)
12+
from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
13+
14+
import tensorrt as trt
15+
16+
17+
def argmax(
18+
network: TRTNetwork,
19+
target: Target,
20+
source_ir: Optional[SourceIR],
21+
name: str,
22+
input: TRTTensor,
23+
dim: int = 0,
24+
keep_dim: bool = False,
25+
) -> TRTTensor:
26+
if not isinstance(input, TRTTensor):
27+
raise RuntimeError(
28+
f"argmax received input {input} that is not part "
29+
"of the TensorRT region!"
30+
)
31+
if dim < 0:
32+
dim = len(tuple(input.shape)) + dim
33+
reduce_mask = 1 << dim
34+
topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
35+
36+
set_layer_name(topk_layer, target, name)
37+
38+
return topk_layer.get_output(1)
39+
40+
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from harness import DispatchTestCase
6+
7+
class TestArgmaxConverter(DispatchTestCase):
8+
@parameterized.expand(
9+
[
10+
("dim_0_keep_dim_false", (3, 4), 0, False)
11+
]
12+
)
13+
14+
def test_argmax(self, _, input_shape, dim, keep_dim):
15+
class ArgMax(nn.Module):
16+
def __init__(self):
17+
super().__init__()
18+
19+
def forward(self, input):
20+
return torch.argmax(input, dim, keep_dim)
21+
22+
23+
input = [torch.randn(*input_shape)]
24+
25+
self.run_test(
26+
ArgMax(),
27+
input,
28+
expected_ops={torch.ops.aten.argmax.default}
29+
)
30+
31+
if __name__ == "__main__":
32+
run_tests()
33+
34+

0 commit comments

Comments
 (0)