Skip to content

Commit 71ca151

Browse files
committed
Index converter
1 parent b9c8578 commit 71ca151

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+18
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,24 @@ def aten_ops_gelu(
169169
)
170170

171171

172+
@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
173+
def aten_ops_index(
174+
network: TRTNetwork,
175+
target: Target,
176+
args: Tuple[Argument, ...],
177+
kwargs: Dict[str, Argument],
178+
name: str,
179+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
180+
return impl.select.gelu(
181+
network,
182+
target,
183+
SourceIR.ATEN,
184+
name,
185+
args[0],
186+
args[1],
187+
)
188+
189+
172190
@dynamo_tensorrt_converter(torch.ops.aten.matmul) # type: ignore[misc]
173191
@dynamo_tensorrt_converter(torch.ops.aten.mm.default) # type: ignore[misc]
174192
def aten_ops_matmul(

py/torch_tensorrt/dynamo/conversion/impl/select.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Optional, cast
1+
from typing import Optional, cast, Union, Sequence
2+
import tensorrt as trt
23

34
import numpy as np
45
from torch.fx.node import Target
@@ -10,6 +11,7 @@
1011
to_numpy,
1112
)
1213
from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
14+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
1315

1416

1517
def select(
@@ -61,3 +63,43 @@ def select(
6163
if len(out.shape) != 1:
6264
layer = network.add_shuffle(out)
6365
return layer.get_output(0)
66+
67+
68+
def select(
69+
network: TRTNetwork,
70+
target: Target,
71+
source_ir: Optional[SourceIR],
72+
name: str,
73+
input: TRTTensor,
74+
index: Union[TRTTensor, Sequence[TRTTensor]]
75+
) -> TRTTensor:
76+
adv_indx_indices = []
77+
tensor_indices = []
78+
79+
for i in len(index):
80+
ind = index[i]
81+
#FIXME: check if the datatype for the indices needs to be casted to INT32
82+
#TRTInterpretor should take care
83+
adv_indx_indices.append(i)
84+
tensor_indices.append(ind)
85+
86+
if not tensor_indices:
87+
identity_layer = network.add_identity(input)
88+
identity_layer.set_output_type(0, trt.int32)
89+
set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
90+
return identity_layer.get_output(0)
91+
elif (len(tensor_indices) == 1):
92+
indices_tensor = tensor_indices[0]
93+
gather_layer = network.add_gather(input, indices_tensor, adv_indx_indices[0])
94+
set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
95+
return gather_layer.get_output(0)
96+
else:
97+
input_shape = input.shape
98+
rank = len(input_shape)
99+
adv_indx_count = len(adv_indx_indices)
100+
input_shape_layer = network.add_shape(input)
101+
set_layer_name(input_shape_layer, target, name + "_index_shape", source_ir)
102+
input_shape_tensor = input_shape_layer.get_output(0)
103+
return input_shape_tensor.get_output(0)
104+
105+

0 commit comments

Comments
 (0)