@@ -89,16 +89,16 @@ def index(
     # if no, then we need to broadcast
 
     last_index = None
-    broadcast_shape_len = 0
     for i, ind in enumerate(index):
         if ind is not None:
             _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
             adv_indx_indices.append(i)
             # torch.nn.parameter.Parameter=> torch.Tensor
-            ind = get_trt_tensor(network, ind, f"parameter_to_fp32_tensor_{i}")
+            ind = get_trt_tensor(network, ind, name + f"_parameter_to_fp32_tensor_{i}")
             if last_index is not None:
-                if not (broadcastable(ind, last_index)):
-                    assert "The indices should be broadcastable"
+                assert broadcastable(
+                    ind, last_index
+                ), "The indices should be broadcastable!"
             last_index = ind
             tensor_indices.append(ind)
 
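Aside, not part of the commit: the removed pattern `assert "The indices should be broadcastable"` could never fail, because asserting a non-empty string literal is always truthy; the replacement asserts the actual `broadcastable(...)` result and keeps the message. A minimal standalone illustration:

# Standalone illustration (not from the diff): asserting a non-empty string
# literal always passes, so the old check could never fire.
broadcast_ok = False
assert "The indices should be broadcastable"  # passes regardless of broadcast_ok
try:
    assert broadcast_ok, "The indices should be broadcastable!"  # fixed form actually raises
except AssertionError as err:
    print(err)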
@@ -128,7 +128,7 @@ def index(
 
     for i in range(rank):
         dim = input_shape[i]
-        dim_tensor = get_trt_tensor(network, dim, f"individual_dim_{i}")
+        dim_tensor = get_trt_tensor(network, dim, name + f"_individual_dim_{i}")
         # dim_tensor_list is a list of tensors
         dim_tensor_list.append(dim_tensor)
 
@@ -165,8 +165,8 @@ def index(
 
     concat_tensor_layer = network.add_concatenation(
         [
-            get_trt_tensor(network, mult_d0, "d0_shape"),
-            get_trt_tensor(network, mult_d1, "d1_shape"),
+            get_trt_tensor(network, mult_d0, name + "_d0_shape"),
+            get_trt_tensor(network, mult_d1, name + "_d1_shape"),
         ]
     )
     set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
@@ -181,15 +181,17 @@ def index(
     # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
     # // j dimension of input x.
     multiplier = get_trt_tensor(
-        network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last"
+        network,
+        dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+        name + "_dim_last",
     )
     cum_adv_index = tensor_indices[adv_indx_count - 1]
     for i in range(adv_indx_count - 2, -1, -1):
         adv_index = convert_binary_elementwise(
             network,
             target,
             source_ir,
-            name + "index_intermediate",
+            name + f"_index_intermediate_{i}",
             trt.ElementWiseOperation.PROD,
             multiplier,
             tensor_indices[i],
@@ -198,7 +200,7 @@ def index(
             network,
             target,
             source_ir,
-            name + "index_sum_intermediate",
+            name + f"_index_sum_intermediate_{i}",
             trt.ElementWiseOperation.SUM,
             cum_adv_index,
             adv_index,
@@ -207,7 +209,7 @@ def index(
             network,
             target,
             source_ir,
-            name + "index_intermediate",
+            name + f"_index_intermediate_xj_{i}",
             trt.ElementWiseOperation.PROD,
             multiplier,
             dim_tensor_list[adv_indx_indices[i]],
@@ -235,7 +237,9 @@ def index(
         == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
     ):
         _LOGGER.debug(f"The indices are continuous in this case")
-        concat_tensor_reshape.append(get_trt_tensor(network, -1, "dynamic_concat"))
+        concat_tensor_reshape.append(
+            get_trt_tensor(network, -1, name + "_dynamic_concat")
+        )
         for i in range(0, rank):
             if i not in adv_indx_indices:
                 curr_dim = dim_tensor_list[i]
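For context, not part of the commit: the converter follows the PyTorch/NumPy advanced-indexing rule here. When the advanced indices occupy contiguous dimensions, the gathered dimension stays in place in the output; when they are not contiguous, it moves to the front, which is why the final output shape is assembled differently in the two branches. A small standalone check of that rule:

# Standalone check (not from the diff) of the rule the two branches mirror.
import torch

x = torch.zeros(2, 3, 4, 5)
idx = torch.tensor([0, 1])
# Advanced indices on contiguous dims 1 and 2: the indexed dim stays in place.
print(x[:, idx, idx, :].shape)  # torch.Size([2, 2, 5])
# Advanced indices on non-contiguous dims 0 and 2: the indexed dim moves to the front.
print(x[idx, :, idx, :].shape)  # torch.Size([2, 3, 5])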
@@ -294,7 +298,7 @@ def index(
         set_layer_name(
             concat_final_shape_layer,
             target,
-            name + "_index_concat_final_shape_layer",
+            name + "_index_continuous_concat_final_shape_layer",
             source_ir,
         )
         concat_final_tensor = concat_final_shape_layer.get_output(0)
@@ -311,17 +315,19 @@ def index(
         reshape_output = unfold_advanced_shuffle_layer.get_output(0)
 
     else:
-        concat_tensor = []
+        _LOGGER.debug(f"The indices are not continuous in this case")
+        concat_final_tensor = []
+        concat_final_tensor.append(cum_adv_index_shape_tensor)
         for i in range(0, rank):
             if i not in adv_indx_indices:
                 curr_dim = dim_tensor_list[i]
-                concat_tensor.append(curr_dim)
+                concat_final_tensor.append(curr_dim)
 
-        concat_layer = network.add_concatenation(concat_tensor)
+        concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
         set_layer_name(
-            concat_layer,
+            concat_final_shape_layer,
             target,
-            name + "_index_concat_final_shape_layer",
+            name + "_index_non_continuous_concat_final_shape_layer",
             source_ir,
         )
         concat_final_tensor = concat_final_shape_layer.get_output(0)
@@ -331,7 +337,7 @@ def index(
         set_layer_name(
             reshape_layer,
             target,
-            name + "_index_shuffle_final_shape_layer",
+            name + "_index_non_continuous_shuffle_final_shape_layer",
             source_ir,
         )
         reshape_output = reshape_layer.get_output(0)
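Also not part of the commit: the `# tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j))` comment above describes a row-major flattening of the advanced indices, which the converter builds out of the chained PROD/SUM elementwise layers. A minimal PyTorch sketch of that arithmetic (hypothetical helper name, plain tensors standing in for TensorRT ITensors):

# Standalone sketch (not from the diff) of the index-flattening formula.
import torch

def flatten_advanced_indices(indices, dims):
    # indices[i] selects along a dimension of size dims[i]; the result is the
    # single flattened index: sum_i indices[i] * prod_{j > i} dims[j].
    cum_adv_index = indices[-1]
    multiplier = dims[-1]
    for i in range(len(indices) - 2, -1, -1):
        cum_adv_index = cum_adv_index + indices[i] * multiplier
        multiplier = multiplier * dims[i]
    return cum_adv_index

x = torch.arange(24).reshape(2, 3, 4)
rows, cols = torch.tensor([0, 1]), torch.tensor([2, 0])
flat = flatten_advanced_indices([rows, cols], [2, 3])  # rows * 3 + cols
# Gathering with the flattened index over the flattened dims matches x[rows, cols].
assert torch.equal(x.reshape(6, 4)[flat], x[rows, cols])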