
Commit fcbd767

Name change of the layers and debugging of the cases where the indices are non-continuous
1 parent 0691d05 commit fcbd767

File tree

1 file changed: +25 -19 lines changed

  • py/torch_tensorrt/dynamo/conversion/impl


py/torch_tensorrt/dynamo/conversion/impl/select.py (+25 -19)
```diff
@@ -89,16 +89,16 @@ def index(
     # if no, then we need to broadcast
 
     last_index = None
-    broadcast_shape_len = 0
     for i, ind in enumerate(index):
         if ind is not None:
             _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
             adv_indx_indices.append(i)
             # torch.nn.parameter.Parameter=> torch.Tensor
-            ind = get_trt_tensor(network, ind, f"parameter_to_fp32_tensor_{i}")
+            ind = get_trt_tensor(network, ind, name + f"_parameter_to_fp32_tensor_{i}")
             if last_index is not None:
-                if not (broadcastable(ind, last_index)):
-                    assert "The indices should be broadcastable"
+                assert broadcastable(
+                    ind, last_index
+                ), "The indices should be broadcastable!"
             last_index = ind
             tensor_indices.append(ind)
```
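A note on the assertion rewrite above: the removed code asserted a string literal, and a non-empty string is always truthy in Python, so the broadcastability check could never fail even when the guard condition was met. The new form asserts the condition itself and attaches the string as the failure message. A minimal standalone sketch of the difference (illustrative, not converter code):

```python
def broken_check(ok: bool) -> None:
    if not ok:
        # Asserting a non-empty string literal is a no-op: the string is
        # truthy, so this never raises even though the guard failed.
        assert "The indices should be broadcastable"

def fixed_check(ok: bool) -> None:
    # Assert the condition itself; the string becomes the error message.
    assert ok, "The indices should be broadcastable!"

broken_check(False)      # silently passes
try:
    fixed_check(False)   # raises
except AssertionError as e:
    print(e)             # -> The indices should be broadcastable!
```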

```diff
@@ -128,7 +128,7 @@ def index(
 
     for i in range(rank):
         dim = input_shape[i]
-        dim_tensor = get_trt_tensor(network, dim, f"individual_dim_{i}")
+        dim_tensor = get_trt_tensor(network, dim, name + f"_individual_dim_{i}")
         # dim_tensor_list is a list of tensors
         dim_tensor_list.append(dim_tensor)
```
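The renames in this and the neighboring hunks all follow one pattern: every constant tensor created through `get_trt_tensor` now carries the converter's `name` argument as a prefix. A plausible reason (the commit does not spell it out) is that bare names such as `individual_dim_0` repeat whenever `index` is lowered more than once into the same TensorRT network, while the prefixed form stays unique per call site. A sketch with hypothetical converter names:

```python
# Two hypothetical aten.index ops lowered into one network.
for name in ("aten_index_0", "aten_index_1"):
    old_style = f"individual_dim_{0}"          # "individual_dim_0" both times
    new_style = name + f"_individual_dim_{0}"  # unique per call site
    print(old_style, "->", new_style)
```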

```diff
@@ -165,8 +165,8 @@ def index(
 
     concat_tensor_layer = network.add_concatenation(
         [
-            get_trt_tensor(network, mult_d0, "d0_shape"),
-            get_trt_tensor(network, mult_d1, "d1_shape"),
+            get_trt_tensor(network, mult_d0, name + "_d0_shape"),
+            get_trt_tensor(network, mult_d1, name + "_d1_shape"),
         ]
     )
     set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
@@ -181,15 +181,17 @@ def index(
     # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
     # // j dimension of input x.
     multiplier = get_trt_tensor(
-        network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last"
+        network,
+        dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
+        name + "_dim_last",
     )
     cum_adv_index = tensor_indices[adv_indx_count - 1]
     for i in range(adv_indx_count - 2, -1, -1):
         adv_index = convert_binary_elementwise(
             network,
             target,
             source_ir,
-            name + "index_intermediate",
+            name + f"_index_intermediate_{i}",
             trt.ElementWiseOperation.PROD,
             multiplier,
             tensor_indices[i],
@@ -198,7 +200,7 @@ def index(
             network,
             target,
             source_ir,
-            name + "index_sum_intermediate",
+            name + f"_index_sum_intermediate_{i}",
             trt.ElementWiseOperation.SUM,
             cum_adv_index,
             adv_index,
@@ -207,7 +209,7 @@ def index(
             network,
             target,
             source_ir,
-            name + "index_intermediate",
+            name + f"_index_intermediate_xj_{i}",
             trt.ElementWiseOperation.PROD,
             multiplier,
             dim_tensor_list[adv_indx_indices[i]],
```
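The comment at the top of the `@@ -181,15 +181,17 @@` hunk is the key to this loop: with advanced index tensors ind_1 ... ind_m over dimensions of sizes x_1 ... x_m, the flattened gather index is sum_{i=1}^{m} ind_i * prod_{j=i+1}^{m} x_j. The elementwise PROD/SUM layers above accumulate it right to left, with `multiplier` carrying the running product of the trailing dimension sizes. A NumPy sketch of the same arithmetic (illustrative only, not the TensorRT code path):

```python
import numpy as np

def flatten_adv_indices(indices, dims):
    # Right-to-left accumulation of sum(ind_i * prod(dims[i+1:])),
    # mirroring the multiplier / cum_adv_index loop in the converter.
    multiplier = dims[-1]
    cum = indices[-1]
    for i in range(len(indices) - 2, -1, -1):
        cum = cum + indices[i] * multiplier
        multiplier = multiplier * dims[i]
    return cum

x = np.arange(24).reshape(2, 3, 4)
i0, i1, i2 = np.array([1, 0]), np.array([2, 1]), np.array([3, 0])
flat = flatten_adv_indices([i0, i1, i2], x.shape)
# Gathering with the flat index over the flattened input
# matches direct advanced indexing.
assert (x.reshape(-1)[flat] == x[i0, i1, i2]).all()
print(flat)  # [23  4]
```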
```diff
@@ -235,7 +237,9 @@ def index(
         == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
     ):
         _LOGGER.debug(f"The indices are continuous in this case")
-        concat_tensor_reshape.append(get_trt_tensor(network, -1, "dynamic_concat"))
+        concat_tensor_reshape.append(
+            get_trt_tensor(network, -1, name + "_dynamic_concat")
+        )
         for i in range(0, rank):
             if i not in adv_indx_indices:
                 curr_dim = dim_tensor_list[i]
```
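For context on the branch being entered here: advanced indices count as "continuous" when they occupy adjacent dimensions of the input, in which case the gathered dimension stays in place; when slices separate them, PyTorch moves the gathered dimension to the front of the result, which is the case the `else` branch further down has to reproduce. A small PyTorch illustration of the two output shapes (my own example, not from the commit):

```python
import torch

x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
idx = torch.tensor([0, 1])

# Continuous: advanced indices on adjacent dims 1 and 2 stay in place.
print(x[:, idx, idx, :].shape)  # torch.Size([2, 2, 5])

# Non-continuous: dims 0 and 2 are separated by a slice, so the
# broadcast index dimension moves to the front of the result.
print(x[idx, :, idx, :].shape)  # torch.Size([2, 3, 5])
```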
```diff
@@ -294,7 +298,7 @@ def index(
         set_layer_name(
             concat_final_shape_layer,
             target,
-            name + "_index_concat_final_shape_layer",
+            name + "_index_continuous_concat_final_shape_layer",
             source_ir,
         )
         concat_final_tensor = concat_final_shape_layer.get_output(0)
```
```diff
@@ -311,17 +315,19 @@ def index(
         reshape_output = unfold_advanced_shuffle_layer.get_output(0)
 
     else:
-        concat_tensor = []
+        _LOGGER.debug(f"The indices are not continuous in this case")
+        concat_final_tensor = []
+        concat_final_tensor.append(cum_adv_index_shape_tensor)
         for i in range(0, rank):
             if i not in adv_indx_indices:
                 curr_dim = dim_tensor_list[i]
-                concat_tensor.append(curr_dim)
+                concat_final_tensor.append(curr_dim)
 
-        concat_layer = network.add_concatenation(concat_tensor)
+        concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
         set_layer_name(
-            concat_layer,
+            concat_final_shape_layer,
             target,
-            name + "_index_concat_final_shape_layer",
+            name + "_index_non_continuous_concat_final_shape_layer",
             source_ir,
         )
         concat_final_tensor = concat_final_shape_layer.get_output(0)
```
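This hunk is the actual fix named in the commit message. The removed `else` branch collected only the un-indexed dimensions into `concat_tensor`, never prepended the broadcast shape of the cumulative advanced index, and then (see the unchanged line that closes the hunk) read the output of `concat_final_shape_layer`, a layer that was only created in the continuous branch. The rewrite puts `cum_adv_index_shape_tensor` first and builds `concat_final_shape_layer` in this branch as well. The intended final shape, sketched with plain lists (hypothetical values):

```python
# Non-continuous case: broadcast index shape first, then untouched dims.
adv_indx_indices = [0, 2]      # advanced indices sit on dims 0 and 2
cum_adv_index_shape = [2]      # broadcast shape of the index tensors
input_shape = [2, 3, 4, 5]

final_shape = list(cum_adv_index_shape) + [
    d for i, d in enumerate(input_shape) if i not in adv_indx_indices
]
print(final_shape)  # [2, 3, 5] -- matches x[idx, :, idx, :].shape above
```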
```diff
@@ -331,7 +337,7 @@ def index(
         set_layer_name(
             reshape_layer,
             target,
-            name + "_index_shuffle_final_shape_layer",
+            name + "_index_non_continuous_shuffle_final_shape_layer",
             source_ir,
         )
         reshape_output = reshape_layer.get_output(0)
```
