@@ -89,16 +89,14 @@ def index(
         # if no, then we need to broadcast
 
         last_index = None
-        broadcast_shape_len = 0
         for i, ind in enumerate(index):
             if ind is not None:
                 _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
                 adv_indx_indices.append(i)
                 # torch.nn.parameter.Parameter=> torch.Tensor
-                ind = get_trt_tensor(network, ind, f"parameter_to_fp32_tensor_{i}")
+                ind = get_trt_tensor(network, ind, name + f"_parameter_to_fp32_tensor_{i}")
                 if last_index is not None:
-                    if not (broadcastable(ind, last_index)):
-                        assert "The indices should be broadcastable"
+                    assert broadcastable(ind, last_index), "The indices should be broadcastable!"
                 last_index = ind
                 tensor_indices.append(ind)
 
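Reviewer note on the `broadcastable` change above: the replaced pattern could never fail, because asserting a non-empty string literal is always truthy, so the broadcast check was silently skipped. The new form asserts the condition itself and attaches the message. A minimal sketch of the difference in plain Python, with a hypothetical `broadcast_ok` flag standing in for `broadcastable(ind, last_index)`:

    def old_check(broadcast_ok: bool) -> None:
        if not broadcast_ok:
            # Asserts a non-empty string, which is always True: never raises.
            assert "The indices should be broadcastable"

    def new_check(broadcast_ok: bool) -> None:
        # Asserts the condition and attaches the message on failure.
        assert broadcast_ok, "The indices should be broadcastable!"

    old_check(False)  # passes silently despite the failed condition
    try:
        new_check(False)
    except AssertionError as err:
        print(err)  # -> The indices should be broadcastable!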
@@ -128,7 +126,7 @@ def index(
 
         for i in range(rank):
             dim = input_shape[i]
-            dim_tensor = get_trt_tensor(network, dim, f"individual_dim_{i}")
+            dim_tensor = get_trt_tensor(network, dim, name + f"individual_dim_{i}")
             # dim_tensor_list is a list of tensors
             dim_tensor_list.append(dim_tensor)
 
@@ -165,8 +163,8 @@ def index(
 
         concat_tensor_layer = network.add_concatenation(
             [
-                get_trt_tensor(network, mult_d0, "d0_shape"),
-                get_trt_tensor(network, mult_d1, "d1_shape"),
+                get_trt_tensor(network, mult_d0, name + "d0_shape"),
+                get_trt_tensor(network, mult_d1, name + "d1_shape"),
             ]
         )
         set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
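The common thread in these renames is prefixing every constant tensor's name with the per-node `name`. Presumably this keeps names unique when `index` is lowered more than once into the same TensorRT network, where two nodes both registering a constant called "d0_shape" can collide. A toy illustration of the failure mode (the registry and node names here are hypothetical, not TensorRT's API):

    registered: set[str] = set()

    def register_constant(tensor_name: str) -> None:
        # Stand-in for a network-wide unique-name requirement.
        if tensor_name in registered:
            raise ValueError(f"duplicate tensor name: {tensor_name}")
        registered.add(tensor_name)

    register_constant("node_a" + "d0_shape")
    register_constant("node_b" + "d0_shape")  # unique thanks to the prefix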
@@ -181,15 +179,15 @@ def index(
         # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
         # // j dimension of input x.
         multiplier = get_trt_tensor(
-            network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last"
+            network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], name + "dim_last"
         )
         cum_adv_index = tensor_indices[adv_indx_count - 1]
         for i in range(adv_indx_count - 2, -1, -1):
             adv_index = convert_binary_elementwise(
                 network,
                 target,
                 source_ir,
-                name + "index_intermediate",
+                name + f"index_intermediate_{i}",
                 trt.ElementWiseOperation.PROD,
                 multiplier,
                 tensor_indices[i],
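The formula in the comment above is the usual row-major flattening of several advanced indices into a single gather index. A small NumPy check of the same arithmetic (NumPy stands in for the TensorRT elementwise layers; `x`, `ind0`, `ind1` are made-up inputs):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)  # x_j dims: (2, 3, 4)
    ind0 = np.array([0, 1])             # advanced index into dim 0
    ind1 = np.array([2, 0])             # advanced index into dim 1

    # tensor index = \sum_i (ind_i * \prod_{j>i} x_j) over the two advanced
    # dims; the trailing dim 2 is not indexed here.
    flat_index = ind0 * x.shape[1] + ind1

    gathered = x.reshape(-1, x.shape[2])[flat_index]
    assert (gathered == x[ind0, ind1]).all()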
@@ -198,7 +196,7 @@ def index(
                 network,
                 target,
                 source_ir,
-                name + "index_sum_intermediate",
+                name + f"index_sum_intermediate_{i}",
                 trt.ElementWiseOperation.SUM,
                 cum_adv_index,
                 adv_index,
@@ -207,7 +205,7 @@ def index(
                 network,
                 target,
                 source_ir,
-                name + "index_intermediate",
+                name + f"index_intermediate_xj_{i}",
                 trt.ElementWiseOperation.PROD,
                 multiplier,
                 dim_tensor_list[adv_indx_indices[i]],
@@ -235,7 +233,7 @@ def index(
             == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
         ):
             _LOGGER.debug(f"The indices are continuous in this case")
-            concat_tensor_reshape.append(get_trt_tensor(network, -1, "dynamic_concat"))
+            concat_tensor_reshape.append(get_trt_tensor(network, -1, name + "dynamic_concat"))
             for i in range(0, rank):
                 if i not in adv_indx_indices:
                     curr_dim = dim_tensor_list[i]
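For context on the `-1` constant being renamed here: in the continuous case the gathered advanced dims are folded into a single dimension, and a `-1` entry lets the reshape infer that dimension's extent from the total element count, the same convention NumPy and PyTorch use. A quick check of the underlying rule, with a made-up shape:

    import numpy as np

    y = np.arange(24).reshape(2, 3, 4)
    # The -1 dim is inferred from the element count: 24 / 4 = 6.
    assert y.reshape(-1, 4).shape == (6, 4)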
@@ -294,7 +292,7 @@ def index(
             set_layer_name(
                 concat_final_shape_layer,
                 target,
-                name + "_index_concat_final_shape_layer",
+                name + "_index_continuous_concat_final_shape_layer",
                 source_ir,
             )
             concat_final_tensor = concat_final_shape_layer.get_output(0)
@@ -311,17 +309,19 @@ def index(
             reshape_output = unfold_advanced_shuffle_layer.get_output(0)
 
         else:
-            concat_tensor = []
+            _LOGGER.debug(f"The indices are not continuous in this case")
+            concat_final_tensor = []
+            concat_final_tensor.append(cum_adv_index_shape_tensor)
             for i in range(0, rank):
                 if i not in adv_indx_indices:
                     curr_dim = dim_tensor_list[i]
-                    concat_tensor.append(curr_dim)
+                    concat_final_tensor.append(curr_dim)
 
-            concat_layer = network.add_concatenation(concat_tensor)
+            concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
             set_layer_name(
-                concat_layer,
+                concat_final_shape_layer,
                 target,
-                name + "_index_concat_final_shape_layer",
+                name + "_index_non_continuous_concat_final_shape_layer",
                 source_ir,
             )
             concat_final_tensor = concat_final_shape_layer.get_output(0)
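The continuous/non-continuous split mirrors NumPy and PyTorch advanced-indexing semantics: when the advanced indices sit next to each other, the broadcast index dims stay in place in the result, but when they are separated by a slice, the broadcast dims move to the front, which is why this branch concatenates `cum_adv_index_shape_tensor` first. A quick check of that rule in PyTorch, with made-up shapes:

    import torch

    x = torch.zeros(2, 3, 4, 5)
    i = torch.tensor([0, 1])

    # Adjacent advanced indices: the index dim stays where the indexed dims were.
    assert x[:, i, i, :].shape == torch.Size([2, 2, 5])

    # Separated advanced indices: the broadcast index dim moves to the front.
    assert x[i, :, i, :].shape == torch.Size([2, 3, 5])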
@@ -331,7 +331,7 @@ def index(
             set_layer_name(
                 reshape_layer,
                 target,
-                name + "_index_shuffle_final_shape_layer",
+                name + "_index_non_continuous_shuffle_final_shape_layer",
                 source_ir,
             )
             reshape_output = reshape_layer.get_output(0)