@@ -81,30 +81,21 @@ def index(
81
81
source_ir : Optional [SourceIR ],
82
82
name : str ,
83
83
input : TRTTensor ,
84
- index : Union [
85
- TRTTensor ,
86
- Sequence [TRTTensor ],
87
- np .ndarray ,
88
- Sequence [np .ndarray ],
89
- torch .Tensor ,
90
- Sequence [torch .Tensor ],
91
- ],
84
+ index : Sequence [Union [TRTTensor , np .ndarray , torch .Tensor ]],
92
85
) -> TRTTensor :
93
86
adv_indx_indices = []
94
87
tensor_indices = []
95
- # _LOGGER.debug(f"The index shape is {index.shape}")
96
88
# check if the input is dynamic
97
89
dynamic_shape = has_dynamic_shape (input .shape )
98
90
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
99
91
# If any is not this flag will be set to False
100
92
is_numpy = True
101
- _LOGGER .debug (f"Checking for the is_numpy flag" )
102
- for i , ind in enumerate (index ):
103
- if ind is None :
104
- continue
105
- if not (isinstance (ind , torch .Tensor ) or isinstance (ind , np .ndarray )):
106
- is_numpy = False
107
- break
93
+ _LOGGER .debug (
94
+ f"Determining whether aten.index constant-index optimization can be invoked"
95
+ )
96
+ is_numpy = all (
97
+ isinstance (ind , (torch .Tensor , np .ndarray )) for ind in index if ind is not None
98
+ )
108
99
# here we need to check if all the index are broadcastable
109
100
# if no, then we need to broadcast
110
101
last_index = None
@@ -117,7 +108,6 @@ def index(
117
108
# other cases are kept as TRTTensor
118
109
if is_numpy :
119
110
ind = to_numpy (ind )
120
- is_numpy = True
121
111
else :
122
112
ind = get_trt_tensor (ctx , ind , name + f"_parameter_to_fp32_tensor_{ i } " )
123
113
if last_index is not None :
@@ -156,9 +146,7 @@ def index(
156
146
for i in range (rank ):
157
147
dim = input_shape [i ]
158
148
dim_tensor = get_trt_tensor (ctx , dim , name + f"_individual_dim_{ i } " )
159
- # dim_tensor_list is a list of tensors or numpy
160
- if is_numpy :
161
- dim_list .append (dim )
149
+ # dim_tensor_list is a list of tensors
162
150
dim_tensor_list .append (dim_tensor )
163
151
164
152
# for cases like
@@ -211,12 +199,12 @@ def index(
211
199
# tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
212
200
# // j dimension of input x.
213
201
if is_numpy :
214
- multiplier = dim_list [adv_indx_indices [adv_indx_count - 1 ]]
202
+ multiplier = input_shape [adv_indx_indices [adv_indx_count - 1 ]]
215
203
cum_adv_index = tensor_indices [adv_indx_count - 1 ]
216
204
for i in range (adv_indx_count - 2 , - 1 , - 1 ):
217
205
adv_index = multiplier * tensor_indices [i ]
218
206
cum_adv_index = cum_adv_index + adv_index
219
- multiplier = multiplier * dim_list [adv_indx_indices [i ]]
207
+ multiplier = multiplier * input_shape [adv_indx_indices [i ]]
220
208
cum_adv_index = get_trt_tensor (
221
209
ctx , cum_adv_index , name + f"_index_sum_intermediate"
222
210
)
0 commit comments