
Commit 7c1a09c

Cast to avoid overflow instead of storing as int64
1 parent 86c7260 commit 7c1a09c

File tree

1 file changed: +9 -10 lines changed


vllm/v1/worker/gpu_model_runner.py

Lines changed: 9 additions & 10 deletions
@@ -241,18 +241,10 @@ def __init__(
                                         device=self.device)
 
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
-        # For long context, may need to store int64 so max idx doesn't overflow
-        # token_indices are calculated by adding (req_idx * max_model_len)
-        # to per-request indices e.g. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
-        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
-        # where M is the max_model_len.
-        max_token_idx = self.max_num_tokens + self.max_num_reqs * \
-            self.max_model_len
         self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                        self.max_model_len,
                                        self.max_num_tokens),
-                                   dtype=np.int32 if max_token_idx <= np.iinfo(
-                                       np.int32).max else np.int64)
+                                   dtype=np.int32)
 
         # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
         # a faster version of creating a new tensor every time. Thus, we should
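
Note: after this hunk the shared arange only has to span max(max_num_reqs + 1, max_model_len, max_num_tokens), each of which fits in int32 on its own, so the int64 fallback for arange_np is no longer needed. A rough sketch of the old versus new bound, using made-up config values rather than anything from this commit:

import numpy as np

# Hypothetical config values, for illustration only (not taken from the commit):
max_num_reqs = 1024
max_model_len = 4_000_000      # a very long context window
max_num_tokens = 8192

# Bound the old code had to cover: the largest token index ever formed.
old_bound = max_num_tokens + max_num_reqs * max_model_len

# Bound the new code has to cover: arange only spans the largest of these sizes.
new_bound = max(max_num_reqs + 1, max_model_len, max_num_tokens)

print(old_bound > np.iinfo(np.int32).max)    # True  -> old code fell back to int64
print(new_bound <= np.iinfo(np.int32).max)   # True  -> int32 is enough after this change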
@@ -283,6 +275,11 @@ def __init__(
                                         pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()
 
+        max_token_idx = self.max_num_reqs * self.max_model_len - 1
+        # if max token idx exceeds int32 max, use int64 to avoid overflow
+        self.token_indices_dtype = np.int32 \
+            if max_token_idx <= np.iinfo(np.int32).max else np.int64
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
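
The dtype decision now lives in a cached attribute instead of widening the arange itself. A minimal standalone sketch of the same selection logic; the helper name and example values below are illustrative, not from the commit:

import numpy as np

def pick_token_indices_dtype(max_num_reqs: int, max_model_len: int):
    # Largest flat index into the (max_num_reqs, max_model_len) token_ids buffer.
    max_token_idx = max_num_reqs * max_model_len - 1
    # Stay in int32 when possible; only very long contexts need int64.
    return np.int32 if max_token_idx <= np.iinfo(np.int32).max else np.int64

print(pick_token_indices_dtype(1024, 8192))        # <class 'numpy.int32'>
print(pick_token_indices_dtype(1024, 4_000_000))   # <class 'numpy.int64'>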
@@ -530,8 +527,10 @@ def _prepare_inputs(
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
         # where M is the max_model_len.
+        # For long context, may need to cast to int64 to avoid overflow
         token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+                         req_indices.astype(self.token_indices_dtype) *
+                         self.input_batch.token_ids_cpu.shape[1])
 
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
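
Why the cast matters: req_indices comes from the int32 arange, and multiplying it by the token_ids buffer width (max_model_len) can exceed the int32 range for long contexts, wrapping into garbage (often negative) indices. A small NumPy sketch with assumed values; exact promotion rules differ slightly across NumPy versions, but the wraparound risk is the point:

import numpy as np

max_model_len = 4_000_000                              # assumed long-context width
req_indices = np.array([0, 1, 600], dtype=np.int32)    # built from the int32 arange
positions_np = np.array([0, 5, 123], dtype=np.int32)

# Without a cast the product can stay int32 and silently wrap around:
wrapped = positions_np + req_indices * max_model_len

# Casting req_indices first promotes the arithmetic to int64, as the commit does:
safe = positions_np + req_indices.astype(np.int64) * max_model_len

print(wrapped)   # e.g. [          0     4000005 -1894967173]  (wrapped, negative)
print(safe)      #      [         0    4000005 2400000123]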
