
Commit 9ca9577

update converter and test case

1 parent 786581b

File tree

3 files changed: +30 -31 lines changed

py/torch_tensorrt/dynamo/conversion/impl/__init__.py (+1 -0)

@@ -2,6 +2,7 @@

 from . import (
     activation,
+    argmax,
     cast,
     condition,
     elementwise,
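Importing the submodule in the package __init__ binds it as an attribute of the package, so the rest of the conversion stack can refer to the new converter as impl.argmax. A quick hypothetical check (assuming a checkout of this commit with torch_tensorrt importable):

from torch_tensorrt.dynamo.conversion import impl

# The `from . import argmax` added above makes the module an attribute of
# the impl package, so the converter function defined below is reachable:
print(impl.argmax.argmax)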
py/torch_tensorrt/dynamo/conversion/impl/argmax.py (+17 -16)

@@ -1,17 +1,13 @@
-from typing import Optional, cast
+from typing import Optional

-import numpy as np
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
-from torch_tensorrt.fx.converters.converter_utils import (
-    get_positive_dim,
-    has_dynamic_shape,
-    to_numpy,
-)
-from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
+from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

-import tensorrt as trt
+from . import squeeze


 def argmax(
@@ -25,16 +21,21 @@ def argmax(
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
-            f"argmax received input {input} that is not part "
-            "of the TensorRT region!"
+            f"argmax received input {input} that is not part " "of the TensorRT region!"
         )
+    if input.dtype == trt.int32:
+        input = cast_trt_tensor(network, input, trt.float32, name)
     if dim < 0:
         dim = len(tuple(input.shape)) + dim
     reduce_mask = 1 << dim
     topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
-
     set_layer_name(topk_layer, target, name)

-    return topk_layer.get_output(1)
-
-
+    out = topk_layer.get_output(1)
+
+    if not keep_dim:
+        out = squeeze.squeeze(
+            network, target, SourceIR.ATEN, name + "_squeeze", out, dim
+        )
+
+    return out
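The converter lowers aten.argmax onto TensorRT's TopK layer: network.add_topk takes the reduction axes as a bitmask (hence reduce_mask = 1 << dim), and output index 1 of the layer carries the winning indices. The new int32 branch casts the input to float32 first, presumably because TopK does not accept INT32 inputs in the targeted TensorRT version, and the new squeeze branch drops the reduced axis when keep_dim is False. A minimal PyTorch sketch (not part of the commit) of the equivalence the converter relies on:

import torch

def argmax_via_topk(x: torch.Tensor, dim: int, keep_dim: bool) -> torch.Tensor:
    # Same negative-dim normalization as the converter.
    if dim < 0:
        dim = x.ndim + dim
    # The index output of a k=1 top-k is the argmax along `dim`,
    # analogous to output 1 of TensorRT's TopK layer.
    out = torch.topk(x, k=1, dim=dim).indices
    # Mirror the converter's squeeze when keep_dim is False.
    if not keep_dim:
        out = out.squeeze(dim)
    return out

x = torch.randn(3, 3)
assert torch.equal(argmax_via_topk(x, 1, True), torch.argmax(x, dim=1, keepdim=True))
assert torch.equal(argmax_via_topk(x, 0, False), torch.argmax(x, dim=0))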
Test case for the argmax converter (+12 -15)

@@ -1,34 +1,31 @@
 import torch
 import torch.nn as nn
+from harness import DispatchTestCase
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from harness import DispatchTestCase
+


 class TestArgmaxConverter(DispatchTestCase):
     @parameterized.expand(
-        [
-            ("dim_0_keep_dim_false", (3, 4), 0, False)
-        ]
+        [
+            ("dim_1_keep_dim_true", (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (4, 4), 0, True),
+            ("dim_0_keep_dim_false", (4, 4), 0, False),
+        ]
     )
-
     def test_argmax(self, _, input_shape, dim, keep_dim):
         class ArgMax(nn.Module):
            def __init__(self):
                super().__init__()

-            def forward(self, input):
+            def forward(self, input):
                 return torch.argmax(input, dim, keep_dim)
-

         input = [torch.randn(*input_shape)]

-        self.run_test(
-            ArgMax(),
-            input,
-            expected_ops={torch.ops.aten.argmax.default}
-        )
-
-if __name__ == "__main__":
-    run_tests()
+        self.run_test(ArgMax(), input, expected_ops={torch.ops.aten.argmax.default})


+if __name__ == "__main__":
+    run_tests()
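The expanded cases now cover both axes with keep_dim True and False; keep_dim only changes the output shape, which is exactly what the converter's new squeeze branch handles. A quick illustration (not from the commit) of the shapes each case expects:

import torch

x = torch.randn(3, 3)
print(torch.argmax(x, 1, True).shape)   # torch.Size([3, 1]): reduced dim kept with size 1
print(torch.argmax(x, 1, False).shape)  # torch.Size([3]): reduced dim squeezed away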
