Commit 4058533

Index converter

1 parent bb5bf00 commit 4058533
2 files changed: +251 −1 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+18)

@@ -186,6 +186,24 @@ def aten_ops_sigmoid(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
+def aten_ops_index(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.select.index(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.tanh.default)
 def aten_ops_tanh(
     network: TRTNetwork,
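
Note: torch.ops.aten.index.Tensor is the ATen overload that tensor (advanced) indexing lowers to, so the converter receives the input as args[0] and the sequence of index tensors as args[1]. A minimal sketch of the op being converted, in plain PyTorch (illustrative only):

    import torch

    x = torch.randn(10, 20, 30)
    ind_0 = torch.tensor([0, 2])
    ind_1 = torch.tensor([1, 3])

    # Equivalent to x[ind_0, ind_1]; this is the call aten_ops_index converts.
    out = torch.ops.aten.index.Tensor(x, [ind_0, ind_1])
    assert out.shape == (2, 30)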

py/torch_tensorrt/dynamo/conversion/impl/select.py (+233 −1)

@@ -1,15 +1,19 @@
-from typing import Optional, cast
+from typing import Optional, Sequence, Union, cast
+import tensorrt as trt
 
 import numpy as np
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
+from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
 from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
+    get_trt_tensor,
     has_dynamic_shape,
+    set_layer_name,
     to_numpy,
 )
 from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
 
 
 def select(
@@ -61,3 +65,231 @@ def select(
     if len(out.shape) != 1:
         layer = network.add_shuffle(out)
     return layer.get_output(0)
+
+
+def index(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    index: Union[TRTTensor, Sequence[TRTTensor]],
+) -> TRTTensor:
+    adv_indx_indices = []
+    tensor_indices = []
+
+    for i, ind in enumerate(index):
+        # FIXME: check whether the index datatype needs to be cast to INT32;
+        # TRTInterpreter should take care of that.
+        adv_indx_indices.append(i)
+        tensor_indices.append(ind)
+
+    if not tensor_indices:
+        # No index tensors: pass the input through unchanged.
+        identity_layer = network.add_identity(input)
+        identity_layer.set_output_type(0, trt.int32)
+        set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
+        return identity_layer.get_output(0)
+    elif len(tensor_indices) == 1:
+        # A single advanced index maps directly onto one gather.
+        indices_tensor = tensor_indices[0]
+        gather_layer = network.add_gather(input, indices_tensor, adv_indx_indices[0])
+        set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
+        return gather_layer.get_output(0)
+    else:
+        input_shape = input.shape
+        rank = len(input_shape)
+        adv_indx_count = len(adv_indx_indices)
+        input_shape_layer = network.add_shape(input)
+        set_layer_name(input_shape_layer, target, name + "_index_shape", source_ir)
+        input_shape_tensor = input_shape_layer.get_output(0)
+
+        # Gather each input dim as a rank-1 tensor so the dims can later be
+        # concatenated into reshape specifications.
+        dim_tensor_list = []
+        for i in range(rank):
+            dim_index = get_trt_tensor(network, np.array([i], dtype=np.int32), name + f"_dim_index_{i}")
+            dim_tensor_layer = network.add_gather(input_shape_tensor, dim_index, 0)
+            set_layer_name(dim_tensor_layer, target, name + f"_index_gather_dim_{i}", source_ir)
+            dim_tensor_list.append(dim_tensor_layer.get_output(0))
+
+        # Transpose the advanced-index axes to the front:
+        # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
+        # where t is a tensor of rank m+n, {x_i} are the axes with a tensor
+        # index, and {y_i} are the axes indexed with ":".
+        # Examples, with x.shape = (10, 20, 30, 40, 50) and ind_1, ind_2
+        # broadcast to shape (2, 3, 4):
+        #   x[:, ind_1, ind_2].shape    == (10, 2, 3, 4, 40, 50)
+        #   x[:, ind_1, :, ind_2].shape == (2, 3, 4, 10, 30, 50)
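+        # e.g. if adv_indx_indices were [1, 3] on a rank-5 input (the
+        # x[:, ind_1, :, ind_2] case above), new_order below would be
+        # [1, 3, 0, 2, 4].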
+        transpose_layer = network.add_shuffle(input)
+        new_order = []
+        for i in range(adv_indx_count):
+            new_order.append(adv_indx_indices[i])
+        for i in range(rank):
+            if i not in adv_indx_indices:
+                new_order.append(i)
+
+        permute_order = trt.Permutation(new_order)
+        transpose_layer.second_transpose = permute_order
+        set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
+        transpose_tensor = transpose_layer.get_output(0)
+
+        # Flatten [x_1, ..., x_m, y_1, ..., y_n] into two dims:
+        # (x_1 * ... * x_m, y_1 * ... * y_n).
+        transpose_tensor_shape_layer = network.add_shape(transpose_tensor)
+        set_layer_name(transpose_tensor_shape_layer, target, name + "_index_transpose_shape", source_ir)
+        transpose_tensor_shape = transpose_tensor_shape_layer.get_output(0)
+
+        mult_d0 = get_trt_tensor(network, np.array([1], dtype=np.int32), name + "_d0_initial")
+        for i in range(adv_indx_count):
+            dim_index = get_trt_tensor(network, np.array([i], dtype=np.int32), name + f"_d0_index_{i}")
+            dim_tensor_layer = network.add_gather(transpose_tensor_shape, dim_index, 0)
+            set_layer_name(dim_tensor_layer, target, name + f"_index_gather_concatOne_{i}", source_ir)
+            d0_gather = dim_tensor_layer.get_output(0)
+            mult_d0 = convert_binary_elementwise(
+                network,
+                target,
+                source_ir,
+                name + f"_index_concatOne_shape_{i}",
+                trt.ElementWiseOperation.PROD,
+                mult_d0,
+                d0_gather,
+            )
+
+        mult_d1 = get_trt_tensor(network, np.array([1], dtype=np.int32), name + "_d1_initial")
+        for i in range(adv_indx_count, rank):
+            dim_index = get_trt_tensor(network, np.array([i], dtype=np.int32), name + f"_d1_index_{i}")
+            dim_tensor_layer = network.add_gather(transpose_tensor_shape, dim_index, 0)
+            set_layer_name(dim_tensor_layer, target, name + f"_index_gather_concatTwo_{i}", source_ir)
+            d1_gather = dim_tensor_layer.get_output(0)
+            mult_d1 = convert_binary_elementwise(
+                network,
+                target,
+                source_ir,
+                name + f"_index_concatTwo_shape_{i}",
+                trt.ElementWiseOperation.PROD,
+                mult_d1,
+                d1_gather,
+            )
+
+        concat_tensor_layer = network.add_concatenation([mult_d0, mult_d1])
+        set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
+        concat_tensor = concat_tensor_layer.get_output(0)
+
+        reshape_layer = network.add_shuffle(transpose_tensor)
+        reshape_layer.set_input(1, concat_tensor)
+        set_layer_name(reshape_layer, target, name + "_index_flatten", source_ir)
+        flatten_tensor = reshape_layer.get_output(0)
+
+        # Linearize the advanced indices over the flattened front dimension:
+        #   cum_adv_index = sum_{i=0}^{m-1} ind_i * prod_{j=i+1}^{m-1} x_j,
+        # where ind_i is index tensor i and x_j is the size of advanced dim j.
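+        # e.g. with two advanced dims of sizes (20, 30), the linear index of
+        # element (ind_0, ind_1) is ind_0 * 30 + ind_1.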
+        multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]]
+        cum_adv_index = tensor_indices[adv_indx_count - 1]
+        for i in range(adv_indx_count - 2, -1, -1):
+            adv_index = convert_binary_elementwise(
+                network,
+                target,
+                source_ir,
+                name + f"_index_intermediate_{i}",
+                trt.ElementWiseOperation.PROD,
+                multiplier,
+                tensor_indices[i],
+            )
+            cum_adv_index = convert_binary_elementwise(
+                network,
+                target,
+                source_ir,
+                name + f"_index_sum_intermediate_{i}",
+                trt.ElementWiseOperation.SUM,
+                cum_adv_index,
+                adv_index,
+            )
+            multiplier = convert_binary_elementwise(
+                network,
+                target,
+                source_ir,
+                name + f"_index_multiplier_{i}",
+                trt.ElementWiseOperation.PROD,
+                multiplier,
+                dim_tensor_list[adv_indx_indices[i]],
+            )
+
+        gather_layer_element = network.add_gather(flatten_tensor, cum_adv_index, 0)
+        set_layer_name(gather_layer_element, target, name + "_index_gather_element", source_ir)
+        gather_out = gather_layer_element.get_output(0)
+
+        cum_adv_index_shape_layer = network.add_shape(cum_adv_index)
+        set_layer_name(cum_adv_index_shape_layer, target, name + "_cum_adv_index_shape", source_ir)
+        cum_adv_index_shape_tensor = cum_adv_index_shape_layer.get_output(0)
+        # Check whether all advanced indices are consecutive.
+        if adv_indx_count == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1:
+            # Consecutive advanced indices: the gathered dimension replaces
+            # the advanced dims in place. Reshape with the gathered dimension
+            # folded in front (-1), then move it into position and unfold it.
+            concat_tensor_reshape = [
+                get_trt_tensor(network, np.array([-1], dtype=np.int32), name + "_dynamic_concat")
+            ]
+            for i in range(0, rank):
+                if i not in adv_indx_indices:
+                    concat_tensor_reshape.append(dim_tensor_list[i])
+
+            concat_tensor_layer = network.add_concatenation(concat_tensor_reshape)
+            set_layer_name(concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir)
+            concat_tensor = concat_tensor_layer.get_output(0)
+
+            regular_index_shuffle_layer = network.add_shuffle(gather_out)
+            regular_index_shuffle_layer.set_input(1, concat_tensor)
+            set_layer_name(regular_index_shuffle_layer, target, name + "_index_regular_index", source_ir)
+            unfold_tensor = regular_index_shuffle_layer.get_output(0)
+
+            # Move the folded gather dimension (axis 0) back to the position
+            # of the first advanced index.
+            transpose_advanced_shuffle_layer = network.add_shuffle(unfold_tensor)
+            new_order = []
+            for i in range(1, adv_indx_indices[0] + 1):
+                new_order.append(i)
+            new_order.append(0)
+            for i in range(adv_indx_indices[0] + 1, rank - adv_indx_count + 1):
+                new_order.append(i)
+
+            permute_order = trt.Permutation(new_order)
+            transpose_advanced_shuffle_layer.second_transpose = permute_order
+            set_layer_name(transpose_advanced_shuffle_layer, target, name + "_index_advanced_shuffle_transpose", source_ir)
+            transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
+
+            # Unfold the gathered dimension into the broadcast index shape.
+            concat_final_tensor = []
+            for i in range(0, adv_indx_indices[0]):
+                concat_final_tensor.append(dim_tensor_list[i])
+
+            concat_final_tensor.append(cum_adv_index_shape_tensor)
+            for i in range(adv_indx_indices[0], rank):
+                if i not in adv_indx_indices:
+                    concat_final_tensor.append(dim_tensor_list[i])
+
+            concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
+            set_layer_name(concat_final_shape_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            unfold_advanced_shuffle_layer = network.add_shuffle(transpose_tensor)
+            unfold_advanced_shuffle_layer.set_input(1, concat_final_tensor)
+            set_layer_name(unfold_advanced_shuffle_layer, target, name + "_index_unfold_advanced", source_ir)
+            reshape_output = unfold_advanced_shuffle_layer.get_output(0)
+
+        else:
+            # Non-consecutive advanced indices: per PyTorch advanced-indexing
+            # semantics, the broadcast index shape goes to the front of the
+            # result, followed by the remaining dims.
+            concat_final_tensor = [cum_adv_index_shape_tensor]
+            for i in range(0, rank):
+                if i not in adv_indx_indices:
+                    concat_final_tensor.append(dim_tensor_list[i])
+
+            concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
+            set_layer_name(concat_final_shape_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+            concat_final_tensor = concat_final_shape_layer.get_output(0)
+
+            reshape_layer = network.add_shuffle(gather_out)
+            reshape_layer.set_input(1, concat_final_tensor)
+            set_layer_name(reshape_layer, target, name + "_index_shuffle_final_shape_layer", source_ir)
+            reshape_output = reshape_layer.get_output(0)
+
+        return reshape_output
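
For reference, the layer sequence above mirrors the standard gather-based lowering of multi-tensor indexing: transpose the advanced dims to the front, flatten, linearize the index tensors, gather, then reshape. A NumPy sketch of the same computation (the helper name and structure are illustrative assumptions, not part of the commit):

    import numpy as np

    def index_gather_reference(x, indices, adv_dims):
        # Move the advanced dims to the front, mirroring the first shuffle.
        rest = [d for d in range(x.ndim) if d not in adv_dims]
        xt = np.transpose(x, list(adv_dims) + rest)

        # Flatten to (prod of advanced dims, prod of the rest).
        adv_shape = [x.shape[d] for d in adv_dims]
        flat = xt.reshape(int(np.prod(adv_shape)), -1)

        # Linearize the indices: cum = sum_i ind_i * prod_{j>i} adv_shape[j].
        cum = np.ravel_multi_index([np.asarray(i) for i in indices], adv_shape)

        # Gather rows, then restore the remaining dims behind the index shape.
        return flat[cum].reshape(cum.shape + tuple(x.shape[d] for d in rest))

    x = np.arange(10 * 20 * 30).reshape(10, 20, 30)
    ind_0, ind_1 = np.array([0, 2]), np.array([1, 3])
    ref = index_gather_reference(x, [ind_0, ind_1], [0, 1])
    assert np.array_equal(ref, x[ind_0, ind_1])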
