@@ -241,18 +241,10 @@ def __init__(
             device=self.device)
 
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
-        # For long context, may need to store int64 so max idx doesn't overflow
-        # token_indices are calculated by adding (req_idx * max_model_len)
-        # to per-request indices e.g. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
-        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
-        # where M is the max_model_len.
-        max_token_idx = self.max_num_tokens + self.max_num_reqs * \
-            self.max_model_len
         self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                        self.max_model_len,
                                        self.max_num_tokens),
-                                   dtype=np.int32 if max_token_idx <= np.iinfo(
-                                       np.int32).max else np.int64)
+                                   dtype=np.int32)
 
         # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
         # a faster version of creating a new tensor every time. Thus, we should
@@ -283,6 +275,11 @@ def __init__(
                                         pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()
 
+        max_token_idx = self.max_num_reqs * self.max_model_len - 1
+        # if max token idx exceeds int32 max, use int64 to avoid overflow
+        self.token_indices_dtype = np.int32 \
+            if max_token_idx <= np.iinfo(np.int32).max else np.int64
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
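
The dtype decision added above reduces to one comparison against the int32 range. A minimal standalone sketch of that check (the batch and context sizes here are assumed for illustration, not taken from the patch):

import numpy as np

max_num_reqs = 1024              # assumed scheduler batch limit
max_model_len = 4 * 1024 * 1024  # assumed 4M-token context window

# Largest flattened index into a (max_num_reqs, max_model_len) token buffer.
max_token_idx = max_num_reqs * max_model_len - 1
token_indices_dtype = (np.int32 if max_token_idx <= np.iinfo(np.int32).max
                       else np.int64)
print(token_indices_dtype)  # int64 path: 4_294_967_295 > 2_147_483_647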
@@ -530,8 +527,10 @@ def _prepare_inputs(
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
         # where M is the max_model_len.
+        # For long context, may need to cast to int64 to avoid overflow
         token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+                         req_indices.astype(self.token_indices_dtype) *
+                         self.input_batch.token_ids_cpu.shape[1])
 
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
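
To see why the cast matters, here is a toy version of the flattening (the row width M and the arrays are made up; in the patch, M is token_ids_cpu.shape[1], i.e., max_model_len):

import numpy as np

M = 10  # stands in for max_model_len
positions_np = np.array([0, 1, 0, 1, 2, 3, 4, 0, 1, 2], dtype=np.int32)
req_indices = np.array([0, 0, 1, 1, 1, 1, 1, 2, 2, 2], dtype=np.int32)

# Cast before multiplying so the product is computed in int64; with int32
# inputs and a large M, req_indices * M would wrap around silently.
token_indices = positions_np + req_indices.astype(np.int64) * M
print(token_indices)  # [ 0  1 10 11 12 13 14 20 21 22]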