
[BugFix][V1] Fix int32 token index overflow when preparing input ids #16806


Merged
merged 7 commits on Apr 23, 2025
Changes from 5 commits
9 changes: 8 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -274,6 +274,11 @@ def __init__(
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()

max_token_idx = self.max_num_reqs * self.max_model_len - 1
# if max token idx exceeds int32 max, use int64 to avoid overflow
self.token_indices_dtype = np.int32 \
if max_token_idx <= np.iinfo(np.int32).max else np.int64

def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
output.
@@ -521,8 +526,10 @@ def _prepare_inputs(
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
# For long context, may need to cast to int64 to avoid overflow
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
req_indices.astype(self.token_indices_dtype) *
self.input_batch.token_ids_cpu.shape[1])

# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
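For reference, the standalone NumPy sketch below (not vLLM code; the sizes are made up for illustration) shows the overflow this change guards against: once max_num_reqs * max_model_len - 1 exceeds the int32 range, the flattened token index wraps around unless the request indices are widened first, which is what the astype(self.token_indices_dtype) cast above does.

```python
import numpy as np

# Hypothetical sizes, chosen so the largest flattened token index
# (max_num_reqs * max_model_len - 1) no longer fits in int32.
max_num_reqs = 4096
max_model_len = 1_048_576  # 1M-token context window

# Same dtype selection as the PR adds in __init__.
max_token_idx = max_num_reqs * max_model_len - 1
token_indices_dtype = (np.int32 if max_token_idx <= np.iinfo(np.int32).max
                       else np.int64)  # -> np.int64 here

# One token of the last request slot: request index and in-request position.
req_indices = np.array([max_num_reqs - 1], dtype=np.int32)
positions_np = np.array([7], dtype=np.int64)

# The int32 multiply wraps around (NumPy integer overflow is not checked)
# and produces a negative index.
overflowed = positions_np + req_indices * np.int32(max_model_len)

# Widening first, as the fixed code does, keeps the index exact.
correct = positions_np + req_indices.astype(token_indices_dtype) * max_model_len

print(overflowed, correct)  # e.g. [-1048569] vs [4293918727]
```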
9 changes: 8 additions & 1 deletion vllm/v1/worker/tpu_model_runner.py
@@ -219,6 +219,11 @@ def __init__(
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
self.num_reqs_paddings = _get_req_paddings(
min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)

max_token_idx = self.max_num_reqs * self.max_model_len - 1
# if max token idx exceeds int32 max, use int64 to avoid overflow
self.token_indices_dtype = np.int32 \
if max_token_idx <= np.iinfo(np.int32).max else np.int64

def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
@@ -457,8 +462,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
# For long context, may need to cast to int64 to avoid overflow
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
req_indices.astype(self.token_indices_dtype) *
Collaborator

Shall we just create req_indices as token_indices_dtype instead of upcasting here?

Collaborator

req_indices may be used in the following logic.

Another idea is to always use an int64 index for such use cases, but that would probably require a bigger change/refactor. Thoughts, @WoosukKwon? Shall we just land this fix and think about a more thorough check in a follow-up PR?

Contributor

I think we can create self.arange_np with dtype self.token_indices_dtype here. The increase in host memory is negligible even at 10M, and we don't pay the cast at each scheduler step.
Computation is in i64 anyway, since self.positions_cpu is already i64, so this wouldn't be the only instance.

Collaborator Author
@sarckk Apr 22, 2025

Makes sense. Maybe we can just keep self.arange_np in int64, similar to self.positions_cpu?

self.input_batch.token_ids_cpu.shape[1])

# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
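As a footnote to the review thread above, the alternative that was floated (keeping the arange buffer in int64 so no per-step cast is needed) would look roughly like the standalone sketch below. The sizes and the np.repeat-based construction of req_indices are illustrative assumptions, not the exact runner code; it reuses the [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] example from the diff comment.

```python
import numpy as np

# Illustrative sizes only.
max_num_tokens = 8192
max_model_len = 1_048_576  # plays the role of M below

# Alternative discussed above: allocate the arange buffer as int64 once at
# init. The extra host memory is 4 bytes per slot, negligible even for very
# large token budgets, and no astype() is needed on each scheduler step.
arange_np = np.arange(max_num_tokens, dtype=np.int64)

# Per-request scheduled token counts for 3 requests (same example as the
# diff comment: positions [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]).
num_scheduled_tokens = np.array([2, 5, 3])
req_indices = np.repeat(arange_np[:3], num_scheduled_tokens)  # already int64
positions_np = np.concatenate([np.arange(n) for n in num_scheduled_tokens])

# [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2],
# computed in int64 end to end, so it cannot wrap for long contexts.
token_indices = positions_np + req_indices * max_model_len
print(token_indices.dtype)  # int64
```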