
Commit 83176fe

Dynamo converter cat (#2343)
1 parent 18dcdd0 commit 83176fe

3 files changed: +53 -0 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+18)
@@ -70,6 +70,24 @@ def aten_ops_batch_norm(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.cat.default)
+def aten_ops_cat(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.cat.cat(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        dim=args_bounds_check(args, 1, 0),
+    )
+
+
 def embedding_param_validator(embedding_node: Node) -> bool:
     scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
     sparse = args_bounds_check(embedding_node.args, 4)
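
For context, a minimal sketch of how this registration would be exercised end to end. It is illustrative only and not part of the commit: the module, input shapes, and compile settings are assumptions; `torch.cat` lowers to `torch.ops.aten.cat.default`, which the `ir="dynamo"` path routes to `aten_ops_cat` above.

import torch
import torch_tensorrt


class CatModule(torch.nn.Module):  # hypothetical toy module, not from this commit
    def forward(self, x, y):
        # torch.cat lowers to torch.ops.aten.cat.default, handled by aten_ops_cat.
        return torch.cat((x, y), dim=1)


inputs = [torch.randn(1, 3, 8, 8).cuda(), torch.randn(1, 5, 8, 8).cuda()]
trt_mod = torch_tensorrt.compile(CatModule().eval().cuda(), ir="dynamo", inputs=inputs)
print(trt_mod(*inputs).shape)  # expected torch.Size([1, 8, 8, 8])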

py/torch_tensorrt/dynamo/conversion/impl/__init__.py (+1)
@@ -4,6 +4,7 @@
     activation,
     attention,
     cast,
+    cat,
     condition,
     conv,
     deconv,
py/torch_tensorrt/dynamo/conversion/impl/cat.py (new file, +34)
@@ -0,0 +1,34 @@
+from typing import Dict, Optional, Sequence, Union
+
+import numpy as np
+import torch
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
+    get_positive_dim,
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+
+def cat(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    dim: int,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    trt_inputs = []
+    for i, each_input in enumerate(input):
+        if not isinstance(each_input, TRTTensor):
+            each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
+        trt_inputs.append(each_input)
+    concat_layer = ctx.net.add_concatenation(trt_inputs)
+    dim = get_positive_dim(dim, len(input[0].shape))
+    concat_layer.axis = dim
+    set_layer_name(concat_layer, target, name + "_gather", source_ir)
+    return concat_layer.get_output(0)
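
As a quick reference for the semantics this converter maps onto TensorRT's concatenation layer: all inputs must match in every dimension except `dim`, and a negative `dim` is first normalized against the rank of the first input (the role `get_positive_dim` plays above). The following eager-mode sketch is written here for illustration and is not taken from the library:

import torch


def cat_reference(tensors, dim=0):
    # Same normalization idea as get_positive_dim: map a negative axis
    # into [0, rank) before concatenating.
    rank = tensors[0].dim()
    dim = dim if dim >= 0 else dim + rank
    return torch.cat(tensors, dim=dim)


a, b = torch.randn(2, 3), torch.randn(2, 5)
assert cat_reference([a, b], dim=-1).shape == (2, 8)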
