
[V1] Add disable_chunked_mm_input arg to disable partial mm input prefill #15837


Merged
merged 4 commits into vllm-project:main on Apr 8, 2025

Conversation

mgoin (Member) commented Mar 31, 2025

Introduces a disable_chunked_mm_input argument to SchedulerConfig (used in V1) that prevents partial scheduling of the tokens belonging to a single multimodal input item. If the scheduled range would cover only part of an mm item, the scheduler rolls back and schedules only the tokens before that item.

This ensures that if a request has a mixed prompt (like text tokens TTTT followed by image tokens IIIIIIIIII) where only some image tokens can be scheduled (like TTTTIIIII, leaving IIIII for the next step), it will be scheduled as TTTT in one step and IIIIIIIIII in the next.
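
As an illustration of the rollback idea, here is a minimal, hypothetical sketch (the function name and signature are invented for this example and are not the PR's actual scheduler code):

def clamp_to_mm_boundary(num_computed_tokens: int, num_new_tokens: int,
                         mm_start: int, mm_len: int) -> int:
    """If the scheduled window would cover only part of a multimodal item,
    shrink it so it stops right before that item (hypothetical helper)."""
    end = num_computed_tokens + num_new_tokens
    mm_end = mm_start + mm_len
    if mm_start < end < mm_end:
        # Partial overlap with the mm item: roll back to just before it.
        num_new_tokens = max(mm_start - num_computed_tokens, 0)
    return num_new_tokens

# Prompt TTTT IIIIIIIIII (4 text + 10 image tokens), token budget of 9:
# only the 4 text tokens are scheduled; the whole image goes in the next step.
assert clamp_to_mm_boundary(0, 9, mm_start=4, mm_len=10) == 4
assert clamp_to_mm_boundary(4, 10, mm_start=4, mm_len=10) == 10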

EDIT added context:
This is needed because the _gather_encoder_outputs function poses a problem in the TPU model runner when chunking through a multimodal item:

end_idx = min(
    num_computed_tokens - start_pos + num_scheduled_tokens,
    num_encoder_tokens)
assert start_idx < end_idx
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
# Slice length depends on the scheduled token count, so the shape varies.
encoder_outputs.append(encoder_output[start_idx:end_idx])

The last line encoder_output[start_idx:end_idx] will slice an on-device tensor with varying shape, triggering recompilation. Padding here is non-obvious because image features have to be aligned with image placeholders in input_ids for merge_multimodal_embeddings. So I think it is natural to allow for the disabling of chunking within multimodal items.
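
To make the recompilation behavior concrete, here is a toy, standalone sketch (assuming a TPU/XLA environment with torch_xla installed; this is not code from the PR). Each distinct slice length yields a different output shape, and each new shape traces and compiles a new XLA graph:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
encoder_output = torch.randn(10, 512, device=device)  # e.g. 10 image tokens

for end_idx in (3, 7, 10):
    chunk = encoder_output[0:end_idx]  # shape (end_idx, 512) varies per step
    xm.mark_step()  # each new shape results in a fresh compilation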


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Mar 31, 2025
WoosukKwon (Collaborator)

What is this change for?

mgoin (Member, Author) commented Apr 1, 2025

@WoosukKwon This is needed because the _gather_encoder_outputs function poses a problem in the TPU model runner when chunking through a multimodal item:

end_idx = min(
    num_computed_tokens - start_pos + num_scheduled_tokens,
    num_encoder_tokens)
assert start_idx < end_idx
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
# Slice length depends on the scheduled token count, so the shape varies.
encoder_outputs.append(encoder_output[start_idx:end_idx])

The last line encoder_output[start_idx:end_idx] will slice an on-device tensor with varying shape, triggering recompilation. Padding here is non-obvious because image features have to be aligned with image placeholders in input_ids for merge_multimodal_embeddings. So I think it is natural to allow for the disabling of chunking within multimodal items.

mergify bot commented Apr 1, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @mgoin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 1, 2025
WoosukKwon (Collaborator)

@mgoin Thanks for the explanation.
I feel like the idea itself makes sense, but I might have to think more about whether there are any edge cases.
BTW, I think we shouldn't add an env variable for this kind of case. vLLM should do this automatically, or if we still want to provide an option, we should provide an engine arg (like disable_custom_all_reduce) rather than an env variable.
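
For reference, a hedged usage sketch of how such an engine arg could be passed. The kwarg and CLI flag spellings below are assumptions based on the SchedulerConfig field name and vLLM's usual underscore-to-dash convention; they are not verified against the merged code:

from vllm import LLM

# Assumed kwarg name, mirroring the SchedulerConfig field added in this PR.
llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    disable_chunked_mm_input=True,
)

# Assumed equivalent CLI form:
#   vllm serve llava-hf/llava-1.5-7b-hf --disable-chunked-mm-input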

@mergify mergify bot removed the needs-rebase label Apr 3, 2025
@mgoin mgoin changed the title Add flag to disable partial mm input chunked prefill Add disable_chunked_mm_input arg to disable partial mm input prefill Apr 3, 2025
@mgoin mgoin changed the title Add disable_chunked_mm_input arg to disable partial mm input prefill [V1] Add disable_chunked_mm_input arg to disable partial mm input prefill Apr 3, 2025
@mgoin mgoin marked this pull request as ready for review April 3, 2025 14:46
mgoin (Member, Author) commented Apr 3, 2025

I think we need to add a check that max_num_batched_tokens is large enough to fit the largest single multimodal item, but otherwise this should be ready for consideration. cc @ywang96 @DarkLight1337
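
A minimal, hypothetical sketch of what such a validation could look like (the function and parameter names are invented for illustration and are not the check that was eventually added):

def validate_mm_budget(max_num_batched_tokens: int,
                       max_tokens_per_mm_item: int,
                       disable_chunked_mm_input: bool) -> None:
    # If chunking within mm items is disabled, every single mm item must
    # fit into one scheduling step, or it can never be scheduled at all.
    if disable_chunked_mm_input and max_tokens_per_mm_item > max_num_batched_tokens:
        raise ValueError(
            f"max_num_batched_tokens ({max_num_batched_tokens}) must be >= "
            f"the largest multimodal item ({max_tokens_per_mm_item} tokens) "
            "when disable_chunked_mm_input is set.")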

DarkLight1337 (Member)

So even if #15712 is merged, it will trigger recompilation on TPU? Can we exclude _gather_encoder_outputs from the graph?

mgoin (Member, Author) commented Apr 7, 2025

@DarkLight1337 _gather_encoder_outputs will create its own graph separate from the model forward pass. Anything that deals with tensors on device will end up creating an XLA graph, but we gain a lot by separating tricky operations that often create recompilation into smaller graphs.

cc @NickLucche, is this good with you?

NickLucche (Contributor) left a comment

This appears to be working really well to address the TPU issue we have. Great job @mgoin !

DarkLight1337 (Member) left a comment

> @DarkLight1337 _gather_encoder_outputs will create its own graph separate from the model forward pass. Anything that deals with tensors on device will end up creating an XLA graph, but we gain a lot by separating tricky operations that often create recompilation into smaller graphs.

Thanks for the explanation!

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) April 8, 2025 01:59
@github-actions github-actions bot added the ready label (ONLY add when PR is ready to merge / full CI is needed) Apr 8, 2025
@vllm-bot vllm-bot merged commit 8e5314a into vllm-project:main Apr 8, 2025
42 of 44 checks passed
ywang96 (Member) commented Apr 8, 2025

Sorry I'm late to this PR, but what happens if the embedding of a multimodal data item is bigger than the max_num_batched_tokens?

NickLucche (Contributor)

From the test, it appears the request is simply not scheduled, so we need to add that check today.

mgoin (Member, Author) commented Apr 8, 2025

I meant to add that before landing; sorry, I didn't realize that auto-merge was on. Will push this up today.

DarkLight1337 (Member)

My bad, I thought the PR was already ready

nishith-fujitsu pushed a commit to nishith-fujitsu/vllm that referenced this pull request Apr 9, 2025
Labels: ready (ONLY add when PR is ready to merge / full CI is needed), v1