Skip to content

Commit 3e4363b

Browse files
committed
Addressing review comments
1 parent 0fc9c75 commit 3e4363b

File tree

1 file changed

+10
-22
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+10
-22
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

+10-22
Original file line numberDiff line numberDiff line change
@@ -81,30 +81,21 @@ def index(
8181
source_ir: Optional[SourceIR],
8282
name: str,
8383
input: TRTTensor,
84-
index: Union[
85-
TRTTensor,
86-
Sequence[TRTTensor],
87-
np.ndarray,
88-
Sequence[np.ndarray],
89-
torch.Tensor,
90-
Sequence[torch.Tensor],
91-
],
84+
index: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
9285
) -> TRTTensor:
9386
adv_indx_indices = []
9487
tensor_indices = []
95-
# _LOGGER.debug(f"The index shape is {index.shape}")
9688
# check if the input is dynamic
9789
dynamic_shape = has_dynamic_shape(input.shape)
9890
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
9991
# If any is not this flag will be set to False
10092
is_numpy = True
101-
_LOGGER.debug(f"Checking for the is_numpy flag")
102-
for i, ind in enumerate(index):
103-
if ind is None:
104-
continue
105-
if not (isinstance(ind, torch.Tensor) or isinstance(ind, np.ndarray)):
106-
is_numpy = False
107-
break
93+
_LOGGER.debug(
94+
f"Determining whether aten.index constant-index optimization can be invoked"
95+
)
96+
is_numpy = all(
97+
isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
98+
)
10899
# here we need to check if all the index are broadcastable
109100
# if no, then we need to broadcast
110101
last_index = None
@@ -117,7 +108,6 @@ def index(
117108
# other cases are kept as TRTTensor
118109
if is_numpy:
119110
ind = to_numpy(ind)
120-
is_numpy = True
121111
else:
122112
ind = get_trt_tensor(ctx, ind, name + f"_parameter_to_fp32_tensor_{i}")
123113
if last_index is not None:
@@ -156,9 +146,7 @@ def index(
156146
for i in range(rank):
157147
dim = input_shape[i]
158148
dim_tensor = get_trt_tensor(ctx, dim, name + f"_individual_dim_{i}")
159-
# dim_tensor_list is a list of tensors or numpy
160-
if is_numpy:
161-
dim_list.append(dim)
149+
# dim_tensor_list is a list of tensors
162150
dim_tensor_list.append(dim_tensor)
163151

164152
# for cases like
@@ -211,12 +199,12 @@ def index(
211199
# tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
212200
# // j dimension of input x.
213201
if is_numpy:
214-
multiplier = dim_list[adv_indx_indices[adv_indx_count - 1]]
202+
multiplier = input_shape[adv_indx_indices[adv_indx_count - 1]]
215203
cum_adv_index = tensor_indices[adv_indx_count - 1]
216204
for i in range(adv_indx_count - 2, -1, -1):
217205
adv_index = multiplier * tensor_indices[i]
218206
cum_adv_index = cum_adv_index + adv_index
219-
multiplier = multiplier * dim_list[adv_indx_indices[i]]
207+
multiplier = multiplier * input_shape[adv_indx_indices[i]]
220208
cum_adv_index = get_trt_tensor(
221209
ctx, cum_adv_index, name + f"_index_sum_intermediate"
222210
)

0 commit comments

Comments (0)