[V1] Get input tokens from scheduler #13339
Conversation
Signed-off-by: Woosuk Kwon <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
cc @comaniac This PR seems to work correctly when using TP (or a single GPU), but PP still generates gibberish outputs.
Awesome! I'll test with PP later.
Signed-off-by: Woosuk Kwon <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Woosuk Kwon <[email protected]>
@comaniac @njhill @LiuXiaoxuanPKU I've updated the PR with some simplifications for spec decoding. PTAL.
                    request.num_tokens)
            if num_scheduled_spec_tokens > 0:
                scheduled_spec_decode_tokens[request.request_id] = (
                    request.spec_token_ids[:num_scheduled_spec_tokens])
This part fixes a bug in #12193
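For intuition, here is a tiny worked example with made-up numbers; the formula for num_scheduled_spec_tokens is my assumption about the surrounding scheduler logic, not code taken from this diff. It shows why the slice above is needed: a request may propose more spec tokens than the scheduler's token budget lets it run in a single step.

```python
# Made-up numbers; the num_scheduled_spec_tokens formula is an assumption
# about the surrounding scheduler logic, not code from this PR.
num_tokens = 10            # prompt + output tokens the request holds so far
num_computed_tokens = 9    # tokens whose KV cache has already been computed
spec_token_ids = [101, 102, 103]   # 3 proposed spec tokens (dummy IDs)
num_new_tokens = 3         # tokens the scheduler can afford to run this step

# Budget left for spec tokens after the not-yet-computed regular token.
num_scheduled_spec_tokens = num_new_tokens + num_computed_tokens - num_tokens
assert num_scheduled_spec_tokens == 2

# Only the spec tokens that were actually scheduled reach the model runner.
scheduled = spec_token_ids[:num_scheduled_spec_tokens]
assert scheduled == [101, 102]
```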
        # 1. Write spec_token_ids to input batch.
        # Step 1. Get req indices that perform spec decode and repeat
        #         the req indices by the number of spec tokens. Note that
        #         for requests that don't perform spec decode, the
        #         number of spec tokens is 0 and the req index is
        #         repeated 0 times.
        # E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
        #       spec_req_indices: [0, 0, 0, 2, 2, 4]
        spec_req_indices = np.repeat(self.arange_np[:num_reqs],
                                     num_spec_tokens_list)
        # spec_offsets: offsets within each spec token list.
        # E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here
        spec_offsets = np.concatenate(
            [self.arange_np[1:val + 1] for val in num_spec_tokens_list])
        # spec_seq_offsets: offsets within each sequence.
        # E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2]
        #       after repeating: [1, 1, 1, 3, 3, 2]
        #       spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
        #                       = [2, 3, 4, 4, 5, 3]
        spec_seq_offsets = np.repeat(
            self.input_batch.num_computed_tokens_cpu[:num_reqs],
            num_spec_tokens_list) + spec_offsets
        # cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3]
        cumsums_spec_offsets = (
            spec_seq_offsets +
            spec_req_indices * self.input_batch.token_ids_cpu.shape[1])
        cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to(
            torch.int64)
        all_spec_token_ids = torch.tensor(all_spec_token_ids,
                                          device="cpu",
                                          dtype=self.input_ids_cpu.dtype)
This part can be skipped as we insert spec_token_ids into token_ids_cpu and treat them as regular input tokens.
        # Step 2. Write spec token ids to input_ids_cpu.
        self.input_batch.token_ids_cpu_tensor.flatten().scatter_(
            0, cumsums_spec_offsets, all_spec_token_ids)
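To make the removed indexing scheme above easier to follow, here is a self-contained sketch (made-up table width, sizes, and token IDs) that reproduces Step 1 and Step 2 on the example values from the comments:

```python
import numpy as np
import torch

# Shapes, sizes, and token IDs below are made up for illustration.
num_reqs = 5
max_model_len = 8                      # row width of the token table ("M")
arange_np = np.arange(64)
num_spec_tokens_list = [3, 0, 2, 0, 1]
num_computed_tokens_cpu = np.array([1, 4, 3, 6, 2])
token_ids_cpu = torch.zeros(num_reqs, max_model_len, dtype=torch.int32)
all_spec_token_ids = list(range(100, 106))   # 6 spec tokens across all requests

# Step 1: request index of every spec token, repeated once per spec token.
spec_req_indices = np.repeat(arange_np[:num_reqs], num_spec_tokens_list)
assert spec_req_indices.tolist() == [0, 0, 0, 2, 2, 4]

# 1-based position of each spec token within its request's spec list.
spec_offsets = np.concatenate(
    [arange_np[1:n + 1] for n in num_spec_tokens_list])
assert spec_offsets.tolist() == [1, 2, 3, 1, 2, 1]

# Position of each spec token within its sequence.
spec_seq_offsets = (np.repeat(num_computed_tokens_cpu, num_spec_tokens_list)
                    + spec_offsets)
assert spec_seq_offsets.tolist() == [2, 3, 4, 4, 5, 3]

# Row-major index into the flattened token table.
cumsums_spec_offsets = torch.from_numpy(
    spec_seq_offsets + spec_req_indices * max_model_len).to(torch.int64)

# Step 2: scatter the spec token IDs into the flattened token table.
token_ids_cpu.flatten().scatter_(
    0, cumsums_spec_offsets,
    torch.tensor(all_spec_token_ids, dtype=token_ids_cpu.dtype))
assert token_ids_cpu[0, 2:5].tolist() == [100, 101, 102]
```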
        # 2. Get spec decode logits indices.
        # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
        #       cu_num_tokens: [4, 104, 107, 207, 209]
        #       num_spec_tokens_list: [3, 0, 2, 0, 1]
        #       num_sampled_tokens: [4, 1, 3, 1, 2]
        #       spec_decode_logits_indices:
        #           [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32)
        num_sampled_tokens = num_spec_tokens_np + 1
        # logits_start_loc: [0, 103, 104, 206, 207]
        logits_start_loc = cu_num_tokens - num_sampled_tokens
        # [0, 103, 104, 206, 207] ->
        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
        logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
        # The following three lines:
        # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
        # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
        #         -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
        cumsums_sampled_offsets = np.repeat(
            cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
        # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        #         - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
        #         -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        total_num_sampled_tokens = num_sampled_tokens.sum()
        sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
                          cumsums_sampled_offsets)

        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
        # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        spec_decode_logits_indices = logits_start_loc + sampled_arange
This part was moved to a separate method for better readability.
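For readers who want to trace the arithmetic, here is a standalone sketch of the same index computation on the example values from the comments. It is an illustration only, not the separate method the PR introduces:

```python
import numpy as np

# Example values copied from the comments above; array sizes are arbitrary.
num_scheduled_tokens = np.array([4, 100, 3, 100, 2], dtype=np.int32)
num_spec_tokens_np = np.array([3, 0, 2, 0, 1], dtype=np.int32)
arange_np = np.arange(1024, dtype=np.int32)

cu_num_tokens = np.cumsum(num_scheduled_tokens)          # [4, 104, 107, 207, 209]
num_sampled_tokens = num_spec_tokens_np + 1              # [4, 1, 3, 1, 2]

# First sampled position of each request, repeated once per sampled token.
logits_start_loc = cu_num_tokens - num_sampled_tokens    # [0, 103, 104, 206, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)

# Per-request 0..k offsets: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1].
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)    # [4, 5, 8, 9, 11]
cumsums_sampled_offsets = np.repeat(
    cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
total_num_sampled_tokens = num_sampled_tokens.sum()
sampled_arange = (arange_np[:total_num_sampled_tokens]
                  - cumsums_sampled_offsets)

spec_decode_logits_indices = logits_start_loc + sampled_arange
assert spec_decode_logits_indices.tolist() == [
    0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
```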
Signed-off-by: Woosuk Kwon <[email protected]>
Overall logic looks good to me but I left two comments - PTAL!
        assert all(
            req_id is not None for req_id in
            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
I thought casting would be much faster? Is there a reason why you did the loop & append for req_ids instead?
Agree. Also, since it's a list, I think it's probably better not to preallocate to the max size like we do for tensors. The list can grow/shrink as needed. So IMO we could change that in the input batch (I actually did that in #13244). This way we can also just keep the type as List[str]. (Not suggesting that for this PR though...)
Regarding this particular line, I think the proposed change would be slightly faster because the original code creates the list three times (two from req_ids[:num_reqs] and another from the list comprehension for all), while the proposed change creates only one.
Agreed with @njhill. All of these are quite hacky and unnecessarily complex. req_ids should be fixed by #13244.
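To make the trade-off being discussed concrete, here is a hypothetical side-by-side of the two variants with simplified data (not the PR's actual code): the cast version walks extra temporary lists, while loop-and-append validates and builds the result in a single pass.

```python
from typing import List, Optional, cast

# Hypothetical buffer: preallocated to the max batch size, only the first
# num_reqs entries are valid. Not the PR's actual data structures.
req_ids_buf: List[Optional[str]] = ["a", "b", "c", None, None]
num_reqs = 3

# Variant 1 (cast): the slice in the assert and the slice in the cast each
# build a temporary list.
assert all(r is not None
           for r in req_ids_buf[:num_reqs]), "req_ids contains None"
req_ids_cast = cast(List[str], req_ids_buf[:num_reqs])

# Variant 2 (loop & append): validate and build the result in one pass,
# with no intermediate slices.
req_ids_loop: List[str] = []
for i in range(num_reqs):
    req_id = req_ids_buf[i]
    assert req_id is not None, "req_ids contains None"
    req_ids_loop.append(req_id)

assert req_ids_cast == req_ids_loop == ["a", "b", "c"]
```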
Looks great to me, it's also a nice simplification! Only minor suggestions.
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]> Signed-off-by: Louis Ulmer <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
This PR changes the scheduler and model runner so that the model runner gets the input token IDs from the scheduler. This change is especially useful when the token IDs are not generated by the model runner (e.g., non-last ranks in PP).
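For illustration only, here is a minimal sketch of the data flow described above, using hypothetical class and field names rather than vLLM's actual ones: the scheduler hands each request's new input token IDs to the model runner, so the runner can build its input batch without depending on locally sampled tokens, which non-last PP ranks never produce.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class CachedRequestData:
    req_id: str
    # Input token IDs the runner has not ingested yet, provided by the
    # scheduler instead of being reconstructed from the runner's sampler.
    new_token_ids: List[int]


@dataclass
class SchedulerOutput:
    scheduled_cached_reqs: List[CachedRequestData]
    # Spec tokens scheduled this step, keyed by request ID (cf. the diff above).
    scheduled_spec_decode_tokens: Dict[str, List[int]] = field(
        default_factory=dict)


class ModelRunner:
    """Toy runner that keeps a per-request token buffer."""

    def __init__(self) -> None:
        self.token_ids: Dict[str, List[int]] = {}

    def update_states(self, scheduler_output: SchedulerOutput) -> None:
        # Works on any pipeline-parallel rank: no dependence on sampled output.
        for req in scheduler_output.scheduled_cached_reqs:
            self.token_ids.setdefault(req.req_id, []).extend(req.new_token_ids)


runner = ModelRunner()
runner.update_states(SchedulerOutput(
    scheduled_cached_reqs=[CachedRequestData("req-0", new_token_ids=[42])]))
assert runner.token_ids["req-0"] == [42]
```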