
Commit ffe53e0

bowang007 authored and gs-olive committed
handle edge cases
1 parent d6a14d9 commit ffe53e0

2 files changed: +86, -10 lines changed

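The edge cases here are the shape semantics of aten.argmax that the TensorRT converter has to reproduce: a missing dim reduces over the flattened input, and a rank-1 input still needs a reduction axis that TensorRT's TopK layer can work with. For orientation, a short eager-mode illustration of those output shapes (plain PyTorch, nothing converter-specific):

import torch

x = torch.randn(3, 3)

# dim=None: the reduction runs over the flattened tensor
print(torch.ops.aten.argmax.default(x, None, False).shape)  # torch.Size([]) - scalar index
print(torch.ops.aten.argmax.default(x, None, True).shape)   # torch.Size([1, 1]) - rank preserved, all ones

# rank-1 input: still a single reduction axis
v = torch.randn(3)
print(torch.ops.aten.argmax.default(v, 0, False).shape)     # torch.Size([])
print(torch.ops.aten.argmax.default(v, 0, True).shape)      # torch.Size([1])

# ordinary case
print(torch.ops.aten.argmax.default(x, 1, False).shape)     # torch.Size([3])
print(torch.ops.aten.argmax.default(x, 1, True).shape)      # torch.Size([3, 1])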
argmax converter (modified):

@@ -1,47 +1,83 @@
-from typing import Optional
+from typing import Optional, Union

 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
+    flatten_dims,
     get_axes_for_reduce_op,
 )
 from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
     set_layer_name,
 )
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor

 from . import squeeze


 def argmax(
-    network: TRTNetwork,
+    ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    dim: int = 0,
+    dim: Union[int, None],
     keep_dim: bool = False,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
             f"argmax received input {input} that is not part " "of the TensorRT region!"
         )
+
     if input.dtype == trt.int32:
-        input = cast_trt_tensor(network, input, trt.float32, name)
-    if dim < 0:
+        input = cast_trt_tensor(ctx, input, trt.float32, name)
+
+    # Three cases to handle here:
+    # 1. dim == None: flatten the input tensor first; the index is reshaped back to all ones (keep_dim=True) or a scalar
+    # 2. input rank == 1: the TopK layer does not support a 1-dimensional topk operation, so broadcast the input to rank 2
+    # 3. normal cases: no additional handling needed
+    out = input
+
+    if dim is None:
+        shuffle_layer = ctx.net.add_shuffle(input)
+        shuffle_layer.reshape_dims = (*flatten_dims(input, 0, -1), 1)
+        set_layer_name(shuffle_layer, target, name + "_flatten")
+        out = shuffle_layer.get_output(0)
+    elif len(input.shape) == 1:
+        shuffle_layer = ctx.net.add_shuffle(input)
+        shuffle_layer.reshape_dims = (*input.shape, 1)
+        set_layer_name(shuffle_layer, target, name + "_broadcast")
+        out = shuffle_layer.get_output(0)
+    elif dim < 0:
         dim = len(tuple(input.shape)) + dim
-    reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape)))
-    topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
+
+    reduce_mask = get_axes_for_reduce_op(0)
+    if dim is not None:
+        reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(out.shape)))
+
+    topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
     set_layer_name(topk_layer, target, name)

     out = topk_layer.get_output(1)

-    if not keep_dim:
+    if dim is None:
+        out_shuffle_layer = ctx.net.add_shuffle(out)
+        out_shuffle_layer.reshape_dims = (1,) * len(input.shape) if keep_dim else ()
+        set_layer_name(out_shuffle_layer, target, name + "_broadcast")
+        out = out_shuffle_layer.get_output(0)
+    elif len(input.shape) == 1:
         out = squeeze.squeeze(
-            network, target, SourceIR.ATEN, name + "_squeeze", out, dim
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_squeeze",
+            out,
+            1 if keep_dim else [0, 1],
         )
+    elif not keep_dim:
+        out = squeeze.squeeze(ctx, target, SourceIR.ATEN, name + "_squeeze", out, dim)

     return out
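Read together, the three branches only move the reduction onto a shape that TopK accepts and then restore the output shape argmax is expected to produce. Below is a minimal sketch of that shape bookkeeping in plain Python (the helper name argmax_output_shape is ours, for illustration only; no TensorRT involved), which can be checked against eager torch.argmax:

from typing import Tuple, Union

import torch


def argmax_output_shape(
    input_shape: Tuple[int, ...], dim: Union[int, None], keep_dim: bool
) -> Tuple[int, ...]:
    # Mirrors the converter's shape handling above, case by case.
    rank = len(input_shape)
    if dim is None:
        # Case 1: input flattened to (numel, 1), TopK over axis 0, index reshaped
        # back to all ones (keep_dim) or a scalar.
        return (1,) * rank if keep_dim else ()
    if rank == 1:
        # Case 2: (N,) broadcast to (N, 1) for TopK, then squeezed back down.
        return (1,) if keep_dim else ()
    # Case 3: ordinary reduction over a single (positive) dim.
    dim = dim % rank
    out = list(input_shape)
    if keep_dim:
        out[dim] = 1
    else:
        out.pop(dim)
    return tuple(out)


# Quick check against eager PyTorch:
for shape, dim, keep in [((3,), 0, True), ((3, 3), None, True), ((4, 4, 4), 0, False)]:
    eager = torch.ops.aten.argmax.default(torch.randn(*shape), dim, keep)
    assert tuple(eager.shape) == argmax_output_shape(shape, dim, keep)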
argmax converter tests (new file):

@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestArgmaxConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            # input dimension == 1
+            ("dim_1_keep_dim_true", (3,), 0, True),
+            ("dim_1_keep_dim_false", (3,), 0, False),
+            # dim == None
+            ("dim_none", (3,), None, True),
+            ("dim_none", (3, 3), None, True),
+            ("dim_none", (3, 3, 3), None, False),
+            # common cases
+            ("dim_1_keep_dim_true", (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (4, 4, 4), 0, True),
+            ("dim_0_keep_dim_false", (4, 4, 4), 0, False),
+        ]
+    )
+    def test_argmax(self, _, input_shape, dim, keep_dim):
+        class ArgMax(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.argmax.default(input, dim, keep_dim)
+
+        input = [torch.randn(*input_shape)]
+
+        self.run_test(ArgMax(), input)
+
+
+if __name__ == "__main__":
+    run_tests()
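The tests above run through the DispatchTestCase harness, which compares the TensorRT result against eager execution for each parameterized case. For a standalone sanity check of one of the new edge cases, the same op can be compiled end to end; the following is a hedged sketch only, assuming the usual torch_tensorrt.compile entry point with ir="dynamo" on a CUDA device (min_block_size=1 is used so the single-op graph is actually handed to TensorRT) and is not part of this commit:

import torch
import torch_tensorrt


class ArgMaxNone(torch.nn.Module):
    def forward(self, x):
        # dim=None with keep_dim=True: one of the edge cases handled by this commit
        return torch.ops.aten.argmax.default(x, None, True)


inputs = [torch.randn(3, 3).cuda()]
trt_module = torch_tensorrt.compile(
    ArgMaxNone().eval().cuda(),
    ir="dynamo",
    inputs=inputs,
    min_block_size=1,  # ensure the lone argmax op is converted rather than falling back
)
print(trt_module(*inputs).shape)  # expected torch.Size([1, 1]), matching eager argmax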
