-from typing import Optional
+from typing import Optional, Union
 
 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
+    flatten_dims,
     get_axes_for_reduce_op,
 )
 from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
     set_layer_name,
 )
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor
 
 from . import squeeze
 
 
 def argmax(
-    network: TRTNetwork,
+    ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    dim: int = 0,
+    dim: Union[int, None],
     keep_dim: bool = False,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
             f"argmax received input {input} that is not part " "of the TensorRT region!"
         )
+
     if input.dtype == trt.int32:
-        input = cast_trt_tensor(network, input, trt.float32, name)
-    if dim < 0:
+        input = cast_trt_tensor(ctx, input, trt.float32, name)
+
+    # Three different cases here:
+    # 1. dim == None, flatten input tensor first, keep_dim will be ignore and the output rank == input rank
+    # 2. input rank == 1: TopK layer does not support 1 dimensional topk operation. Broadcast input to rank == 2
+    # 3. normal cases, no additional handlings
+    out = input
+
+    if dim is None:
+        shuffle_layer = ctx.net.add_shuffle(input)
+        shuffle_layer.reshape_dims = (*flatten_dims(input, 0, -1), 1)
+        set_layer_name(shuffle_layer, target, name + "_flatten")
+        out = shuffle_layer.get_output(0)
+    elif len(input.shape) == 1:
+        shuffle_layer = ctx.net.add_shuffle(input)
+        shuffle_layer.reshape_dims = (*input.shape, 1)
+        set_layer_name(shuffle_layer, target, name + "_broadcast")
+        out = shuffle_layer.get_output(0)
+    elif dim < 0:
         dim = len(tuple(input.shape)) + dim
-    reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape)))
-    topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
+
+    reduce_mask = get_axes_for_reduce_op(0)
+    if dim is not None:
+        reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(out.shape)))
+
+    topk_layer = ctx.net.add_topk(out, trt.TopKOperation.MAX, 1, reduce_mask)
     set_layer_name(topk_layer, target, name)
 
     out = topk_layer.get_output(1)
 
-    if not keep_dim:
+    if dim is None:
+        out_shuffle_layer = ctx.net.add_shuffle(out)
+        out_shuffle_layer.reshape_dims = (1,) * len(input.shape) if keep_dim else ()
+        set_layer_name(out_shuffle_layer, target, name + "_broadcast")
+        out = out_shuffle_layer.get_output(0)
+    elif len(input.shape) == 1:
         out = squeeze.squeeze(
-            network, target, SourceIR.ATEN, name + "_squeeze", out, dim
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name + "_squeeze",
+            out,
+            1 if keep_dim else [0, 1],
         )
+    elif not keep_dim:
+        out = squeeze.squeeze(ctx, target, SourceIR.ATEN, name + "_squeeze", out, dim)
 
     return out
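For context, the three-case handling added above mirrors the shape semantics of torch.argmax (aten.argmax) that the converter has to reproduce. A minimal sketch in plain PyTorch, independent of the change itself:

import torch

x = torch.rand(3, 4)

# dim=None: the input is treated as flattened and a single 0-dim index is returned.
torch.argmax(x).shape                        # torch.Size([])

# dim given: the reduced dimension is dropped unless keepdim=True.
torch.argmax(x, dim=1).shape                 # torch.Size([3])
torch.argmax(x, dim=1, keepdim=True).shape   # torch.Size([3, 1])

# rank-1 input: the result is still a 0-dim index; per the comment in the diff,
# TensorRT's TopK layer cannot reduce a rank-1 tensor, hence the broadcast-to-rank-2 path.
v = torch.rand(5)
torch.argmax(v, dim=0).shape                 # torch.Size([])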