
Commit 7d215af

Name change of the layers and debugging of the cases where indices are non-continuous
1 parent 0691d05 commit 7d215af

File tree

  • py/torch_tensorrt/dynamo/conversion/impl/select.py

1 file changed: +19 −19 lines

py/torch_tensorrt/dynamo/conversion/impl/select.py

@@ -89,16 +89,14 @@ def index(
     # if no, then we need to broadcast
 
     last_index = None
-    broadcast_shape_len = 0
     for i, ind in enumerate(index):
         if ind is not None:
             _LOGGER.debug(f"Shape of {i} index is {ind.shape}")
             adv_indx_indices.append(i)
             # torch.nn.parameter.Parameter=> torch.Tensor
-            ind = get_trt_tensor(network, ind, f"parameter_to_fp32_tensor_{i}")
+            ind = get_trt_tensor(network, ind, name + f"_parameter_to_fp32_tensor_{i}")
             if last_index is not None:
-                if not (broadcastable(ind, last_index)):
-                    assert "The indices should be broadcastable"
+                assert broadcastable(ind, last_index), "The indices should be broadcastable!"
             last_index = ind
             tensor_indices.append(ind)
 
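
For context on the assert change above: the removed `assert "The indices should be broadcastable"` could never fire, since a non-empty string literal is always truthy, so incompatible indices slipped through silently; the new form actually evaluates the predicate. A minimal standalone sketch of both behaviors (the broadcastable_shapes helper below is illustrative, not the converter's broadcastable utility):

def broadcastable_shapes(a: tuple, b: tuple) -> bool:
    # Right-align the shapes; each dim pair must match or one must be 1.
    for x, y in zip(reversed(a), reversed(b)):
        if x != y and 1 not in (x, y):
            return False
    return True

# Old pattern: asserting a non-empty string is a no-op, it never raises.
assert "The indices should be broadcastable"

# New pattern: the predicate is checked and the string is the error message.
assert broadcastable_shapes((3, 1), (3, 4)), "The indices should be broadcastable!"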

@@ -128,7 +126,7 @@ def index(
 
         for i in range(rank):
             dim = input_shape[i]
-            dim_tensor = get_trt_tensor(network, dim, f"individual_dim_{i}")
+            dim_tensor = get_trt_tensor(network, dim, name + f"individual_dim_{i}")
             # dim_tensor_list is a list of tensors
             dim_tensor_list.append(dim_tensor)
 
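
The recurring rename in this commit prefixes every generated constant with the converter's name argument, so two index nodes in one graph no longer produce clashing names like individual_dim_0. A rough sketch of that collision, using a plain dict as a stand-in registry (register_constant is hypothetical, not a torch_tensorrt API):

registry: dict = {}

def register_constant(name: str, value: int) -> None:
    # Stand-in for get_trt_tensor-style creation: each name may be used once.
    if name in registry:
        raise ValueError(f"duplicate constant name: {name}")
    registry[name] = value

# Unprefixed names collide as soon as a second index node is converted:
register_constant("individual_dim_0", 4)
# register_constant("individual_dim_0", 8)  # would raise ValueError

# Prefixing with the per-node name keeps entries unique across nodes:
for node_name, dim in [("index_1", 4), ("index_2", 8)]:
    register_constant(node_name + "individual_dim_0", dim)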

@@ -165,8 +163,8 @@ def index(
 
         concat_tensor_layer = network.add_concatenation(
             [
-                get_trt_tensor(network, mult_d0, "d0_shape"),
-                get_trt_tensor(network, mult_d1, "d1_shape"),
+                get_trt_tensor(network, mult_d0, name + "d0_shape"),
+                get_trt_tensor(network, mult_d1, name + "d1_shape"),
             ]
         )
         set_layer_name(concat_tensor_layer, target, name + "_index_Concat", source_ir)
@@ -181,15 +179,15 @@ def index(
         # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
         # // j dimension of input x.
         multiplier = get_trt_tensor(
-            network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last"
+            network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], name + "dim_last"
         )
         cum_adv_index = tensor_indices[adv_indx_count - 1]
         for i in range(adv_indx_count - 2, -1, -1):
             adv_index = convert_binary_elementwise(
                 network,
                 target,
                 source_ir,
-                name + "index_intermediate",
+                name + f"index_intermediate_{i}",
                 trt.ElementWiseOperation.PROD,
                 multiplier,
                 tensor_indices[i],
@@ -198,7 +196,7 @@ def index(
                 network,
                 target,
                 source_ir,
-                name + "index_sum_intermediate",
+                name + f"index_sum_intermediate_{i}",
                 trt.ElementWiseOperation.SUM,
                 cum_adv_index,
                 adv_index,
@@ -207,7 +205,7 @@ def index(
                 network,
                 target,
                 source_ir,
-                name + "index_intermediate",
+                name + f"index_intermediate_xj_{i}",
                 trt.ElementWiseOperation.PROD,
                 multiplier,
                 dim_tensor_list[adv_indx_indices[i]],
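
The comment in the hunk above is the heart of this loop: the gathered indices are linearized row-major over the advanced dimensions, flat = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), accumulated by the alternating PROD/SUM elementwise layers. A hedged pure-PyTorch check of that arithmetic for two advanced dimensions (no TensorRT involved):

import torch

x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
ind0 = torch.tensor([0, 1])  # advanced index into dim 0 (size 2)
ind1 = torch.tensor([2, 0])  # advanced index into dim 1 (size 3)

# flat = ind0 * x_1 + ind1 with x_1 = 3, i.e. the cum_adv_index the
# PROD/SUM loop accumulates for m = 2 advanced indices.
flat = ind0 * x.shape[1] + ind1

# Gathering rows of the (2*3, 4) view with the flat index reproduces
# ordinary advanced indexing.
assert torch.equal(x.reshape(-1, 4)[flat], x[ind0, ind1])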
@@ -235,7 +233,7 @@ def index(
             == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
         ):
             _LOGGER.debug(f"The indices are continuous in this case")
-            concat_tensor_reshape.append(get_trt_tensor(network, -1, "dynamic_concat"))
+            concat_tensor_reshape.append(get_trt_tensor(network, -1, name + "dynamic_concat"))
             for i in range(0, rank):
                 if i not in adv_indx_indices:
                     curr_dim = dim_tensor_list[i]
@@ -294,7 +292,7 @@ def index(
             set_layer_name(
                 concat_final_shape_layer,
                 target,
-                name + "_index_concat_final_shape_layer",
+                name + "_index_continuous_concat_final_shape_layer",
                 source_ir,
             )
             concat_final_tensor = concat_final_shape_layer.get_output(0)
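
This continuous branch matches PyTorch's indexing rule: when the advanced indices sit in consecutive dimensions, their broadcast shape replaces those dimensions in place, which is what the reshape assembled here reconstructs. A small shape check of that rule:

import torch

x = torch.zeros(5, 6, 7, 8)
idx = torch.zeros(2, 3, dtype=torch.long)

# Advanced indices in the adjacent dims 1 and 2: their broadcast shape
# (2, 3) stays in place, giving (5, 2, 3, 8).
assert x[:, idx, idx].shape == (5, 2, 3, 8)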
@@ -311,17 +309,19 @@ def index(
             reshape_output = unfold_advanced_shuffle_layer.get_output(0)
 
         else:
-            concat_tensor = []
+            _LOGGER.debug(f"The indices are not continuous in this case")
+            concat_final_tensor = []
+            concat_final_tensor.append(cum_adv_index_shape_tensor)
             for i in range(0, rank):
                 if i not in adv_indx_indices:
                     curr_dim = dim_tensor_list[i]
-                    concat_tensor.append(curr_dim)
+                    concat_final_tensor.append(curr_dim)
 
-            concat_layer = network.add_concatenation(concat_tensor)
+            concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
             set_layer_name(
-                concat_layer,
+                concat_final_shape_layer,
                 target,
-                name + "_index_concat_final_shape_layer",
+                name + "_index_non_continuous_concat_final_shape_layer",
                 source_ir,
             )
             concat_final_tensor = concat_final_shape_layer.get_output(0)
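
The fix in this else branch lines up with PyTorch's rule for non-continuous advanced indices: when a slice separates them, the broadcast index shape is hoisted to the front of the result, which is why cum_adv_index_shape_tensor now leads concat_final_tensor. A small shape check:

import torch

x = torch.zeros(5, 6, 7, 8)
idx = torch.zeros(2, 3, dtype=torch.long)

# Advanced indices in dims 0 and 2 with a slice in between: their broadcast
# shape (2, 3) moves to the front, giving (2, 3, 6, 8).
assert x[idx, :, idx].shape == (2, 3, 6, 8)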
@@ -331,7 +331,7 @@ def index(
             set_layer_name(
                 reshape_layer,
                 target,
-                name + "_index_shuffle_final_shape_layer",
+                name + "_index_non_continuous_shuffle_final_shape_layer",
                 source_ir,
             )
             reshape_output = reshape_layer.get_output(0)
