[V1] Get input tokens from scheduler #13339

Merged: 8 commits merged into main on Feb 17, 2025

Conversation

@WoosukKwon (Collaborator) commented Feb 15, 2025:

This PR changes the scheduler and model runner so that the model runner gets the input token IDs from the scheduler. This change is especially useful when the token IDs are not generated by the model runner (e.g., non-last ranks in PP).
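A minimal sketch of the data flow described above, with illustrative names only (SchedulerOutput.scheduled_token_ids and ModelRunner.prepare_inputs are placeholders, not the actual vLLM interfaces): the scheduler output carries the token IDs to feed in this step, and the model runner copies them from there instead of relying on tokens it sampled itself.

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class SchedulerOutput:
    # Hypothetical field: token IDs to be fed to the model this step,
    # keyed by request ID. Provided by the scheduler, so every PP rank
    # sees the same inputs even if it never ran the sampler.
    scheduled_token_ids: Dict[str, List[int]] = field(default_factory=dict)

class ModelRunner:
    def __init__(self) -> None:
        # Per-request token buffer owned by the runner (illustrative).
        self.token_ids: Dict[str, List[int]] = {}

    def prepare_inputs(self, scheduler_output: SchedulerOutput) -> None:
        # Take the new input tokens from the scheduler output rather than
        # from the runner's own sampled output, which non-last PP ranks
        # do not have.
        for req_id, token_ids in scheduler_output.scheduled_token_ids.items():
            self.token_ids.setdefault(req_id, []).extend(token_ids)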


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only the fastcheck CI runs, covering a small and essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to be added to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot added the v1 label on Feb 15, 2025
@WoosukKwon (Collaborator, Author) commented:

cc @comaniac This PR seems to work correctly when using TP (or single GPU), but PP still generates gibberish outputs.

@WoosukKwon added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Feb 15, 2025
@comaniac (Collaborator) left a comment:

Awesome! I'll test with PP later.


mergify bot commented Feb 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @WoosukKwon.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Feb 16, 2025
mergify bot removed the needs-rebase label on Feb 16, 2025
@WoosukKwon (Collaborator, Author) commented:

@comaniac @njhill @LiuXiaoxuanPKU I've updated the PR with some simplifications for spec decoding. PTAL.

    request.num_tokens)
if num_scheduled_spec_tokens > 0:
    scheduled_spec_decode_tokens[request.request_id] = (
        request.spec_token_ids[:num_scheduled_spec_tokens])
@WoosukKwon (Collaborator, Author) commented:

This part fixes a bug in #12193
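For illustration, a self-contained sketch of what the quoted scheduler snippet does, with the surrounding computation of num_scheduled_spec_tokens assumed rather than shown: only the spec tokens that actually fit into this step's schedule are recorded for the request.

from typing import Dict, List

scheduled_spec_decode_tokens: Dict[str, List[int]] = {}

def record_scheduled_spec_tokens(request_id: str,
                                 spec_token_ids: List[int],
                                 num_scheduled_spec_tokens: int) -> None:
    # Keep only the prefix of the proposed spec tokens that was actually
    # scheduled; requests with no scheduled spec tokens get no entry.
    if num_scheduled_spec_tokens > 0:
        scheduled_spec_decode_tokens[request_id] = (
            spec_token_ids[:num_scheduled_spec_tokens])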

Comment on lines -444 to -477
# 1. Write spec_token_ids to input batch.
# Step 1. Get req indices that perform spec decode and repeat
#         the req indices by the number of spec tokens. Note
#         for requests that don't perform spec decode, the
#         number of spec tokens is 0 and the req index is
#         repeated 0 times.
# E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
#       spec_req_indices: [0, 0, 0, 2, 2, 4]
spec_req_indices = np.repeat(self.arange_np[:num_reqs],
                             num_spec_tokens_list)
# spec_offsets: offsets within each spec token list.
# E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here
spec_offsets = np.concatenate(
    [self.arange_np[1:val + 1] for val in num_spec_tokens_list])
# spec_seq_offsets: offsets within each sequence.
# E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2]
#       after repeating: [1, 1, 1, 3, 3, 2]
#       spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
#                       = [2, 3, 4, 4, 5, 3]
spec_seq_offsets = np.repeat(
    self.input_batch.num_computed_tokens_cpu[:num_reqs],
    num_spec_tokens_list) + spec_offsets
# cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3]
cumsums_spec_offsets = (
    spec_seq_offsets +
    spec_req_indices * self.input_batch.token_ids_cpu.shape[1])
cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to(
    torch.int64)
all_spec_token_ids = torch.tensor(all_spec_token_ids,
                                  device="cpu",
                                  dtype=self.input_ids_cpu.dtype)
@WoosukKwon (Collaborator, Author) commented:

This part can be skipped as we insert spec_token_ids into token_ids_cpu and treat them as regular input tokens.
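A minimal sketch of that simplification, assuming token_ids_cpu is the per-request CPU buffer from the quoted code and that the caller tracks how many tokens each request's row already holds (the helper name is hypothetical):

import numpy as np
from typing import List

def append_spec_tokens(token_ids_cpu: np.ndarray, req_index: int,
                       num_tokens: int, spec_token_ids: List[int]) -> int:
    # Write the scheduled spec tokens directly after the request's existing
    # tokens so the regular input-token path picks them up unchanged.
    end = num_tokens + len(spec_token_ids)
    token_ids_cpu[req_index, num_tokens:end] = spec_token_ids
    return end  # updated token count for this request's row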

Comment on lines -476 to -514
# Step 2. Write spec token ids to input_ids_cpu.
self.input_batch.token_ids_cpu_tensor.flatten().scatter_(
    0, cumsums_spec_offsets, all_spec_token_ids)

# 2. Get spec decode logits indices.
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
#       cu_num_tokens: [4, 104, 107, 207, 209]
#       num_spec_tokens_list: [3, 0, 2, 0, 1]
#       num_sampled_tokens: [4, 1, 3, 1, 2]
#       spec_decode_logits_indices:
#       [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32)
num_sampled_tokens = num_spec_tokens_np + 1
# logits_start_loc: [0, 103, 104, 206, 207]
logits_start_loc = cu_num_tokens - num_sampled_tokens
# [0, 103, 104, 206, 207] ->
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
# The following three lines:
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
#         -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_sampled_offsets = np.repeat(
    cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
#         - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
#         -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
total_num_sampled_tokens = num_sampled_tokens.sum()
sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
                  cumsums_sampled_offsets)

# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
spec_decode_logits_indices = logits_start_loc + sampled_arange
@WoosukKwon (Collaborator, Author) commented:

This part was moved to a separate method for better readability.
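For reference, the removed index computation rewritten as a standalone helper, roughly what such a separate method could look like (the function name and signature are hypothetical; the logic and example values are taken from the comments in the removed block):

import numpy as np

def spec_decode_logits_indices(num_spec_tokens: np.ndarray,
                               cu_num_tokens: np.ndarray) -> np.ndarray:
    # E.g., num_spec_tokens: [3, 0, 2, 0, 1], cu_num_tokens: [4, 104, 107, 207, 209]
    # -> [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
    num_sampled_tokens = num_spec_tokens + 1
    # First sampled position per request, repeated once per sampled token:
    # [0, 103, 104, 206, 207] -> [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
    logits_start_loc = np.repeat(cu_num_tokens - num_sampled_tokens,
                                 num_sampled_tokens)
    # Per-request offsets 0..num_sampled-1, flattened:
    # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
    cu_num_sampled = np.cumsum(num_sampled_tokens)
    offsets = (np.arange(num_sampled_tokens.sum()) -
               np.repeat(cu_num_sampled - num_sampled_tokens,
                         num_sampled_tokens))
    return logits_start_loc + offsets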

@ywang96 (Member) left a comment:

Overall logic looks good to me but I left two comments - PTAL!

assert all(
    req_id is not None for req_id in
    self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
@ywang96 (Member) commented Feb 17, 2025:

I thought casting would be much faster. Is there a reason why you did the loop & append for req_ids instead?

@njhill (Member) commented:

Agree.

Also, since it's a list, I think it's probably better not to preallocate to the max size like we do for tensors. The list can grow/shrink as needed. So IMO we could change that in the input batch (I actually did that in #13244). This way we can also just keep the type as List[str]. (Not suggesting it for this PR though...)
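A rough sketch of the alternative being suggested, using a deliberately simplified InputBatch (names and the swap-remove detail are illustrative, not the actual vLLM implementation): req_ids stays a plain List[str] that grows and shrinks with the batch, so there are no Optional holes and no cast is needed.

from typing import List

class InputBatch:
    def __init__(self) -> None:
        # Dynamic list instead of a fixed-size Optional[str] buffer.
        self.req_ids: List[str] = []

    def add_request(self, req_id: str) -> None:
        self.req_ids.append(req_id)

    def remove_request(self, req_index: int) -> None:
        # Swap-remove keeps indices dense without leaving None holes behind.
        self.req_ids[req_index] = self.req_ids[-1]
        self.req_ids.pop()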

@WoosukKwon (Collaborator, Author) commented:

Regarding this particular line, I think the proposed change would be slightly faster because the original code creates the list three times (twice from req_ids[:num_reqs] and once more from the list comprehension for all), while the proposed change creates it only once.

Agreed with @njhill. All of these are quite hacky and unnecessarily complex. req_ids should be fixed by #13244.
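A standalone sketch of the loop-and-append variant being discussed (the function wrapper is added here only to make it self-contained): a single pass over the preallocated buffer both checks for None and collects the IDs, instead of slicing twice and casting.

from typing import List, Optional

def collect_req_ids(req_ids: List[Optional[str]], num_reqs: int) -> List[str]:
    # One pass: validate and gather the first num_reqs entries without
    # building intermediate slice copies.
    out: List[str] = []
    for i in range(num_reqs):
        req_id = req_ids[i]
        assert req_id is not None, "req_ids contains None"
        out.append(req_id)
    return out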

@njhill (Member) left a comment:

Looks great to me; it's also a nice simplification!

Only minor suggestions.


@WoosukKwon merged commit 4c21ce9 into main on Feb 17, 2025 (38 of 44 checks passed)
@WoosukKwon deleted the v1-scheduler-input branch on February 17, 2025 at 19:01
xjpang pushed a commit to xjpang/vllm that referenced this pull request Feb 20, 2025
Akshat-Tripathi pushed a commit to krai/vllm that referenced this pull request Mar 3, 2025
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025