
Support embedding models in V1 with a dedicated model_runner #18015


Open
wants to merge 77 commits into base: main

Conversation

@maxdebayser (Contributor) commented May 12, 2025

This is an alternative to #16188. In that other PR, I implemented embedding model support on the same model runner as the decoder models. This had the advantage that the code changes were fairly minimal. The other advantage, in my opinion, is that a single model runner implementation is less likely to become stale, since new features and bug fixes only need to be applied to one code base. However, there were concerns about the performance implications and code complexity of a single implementation that tries to handle all cases.

In this PR I started by reverting all changes to the GPUModelRunner and created a GPUPoolingModelRunner, essentially by deleting everything related to sampling. In that state it already passed the embedding model unit tests, but there was still a lot of duplicated or unnecessary code.

Now I'm finished with the refactoring. There is now a GPUBaseModelRunner that contains the common code, while GPUModelRunner and GPUPoolingModelRunner implement the missing pieces.
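
The split looks roughly like this (the class names are the ones introduced in this PR; the method names and the self.model, self.sampler and self.pooling_metadata attributes below are illustrative assumptions, not the actual vLLM interfaces):

from abc import ABC, abstractmethod


class GPUBaseModelRunner(ABC):
    """Common code: input batch management, m-rope, sliding window, multi-modal handling, ..."""

    def execute_model(self, scheduler_output):
        # Shared forward pass over the batched inputs (details omitted).
        hidden_states = self._run_model(scheduler_output)
        # Subclasses decide what to do with the hidden states.
        return self._process_hidden_states(hidden_states, scheduler_output)

    @abstractmethod
    def _process_hidden_states(self, hidden_states, scheduler_output):
        ...


class GPUModelRunner(GPUBaseModelRunner):
    """Generative path: compute logits and sample (plus decode-only
    features such as cascade attention)."""

    def _process_hidden_states(self, hidden_states, scheduler_output):
        logits = self.model.compute_logits(hidden_states)
        return self.sampler(logits)


class GPUPoolingModelRunner(GPUBaseModelRunner):
    """Pooling path: no sampler; the model's pooler turns hidden states
    into embeddings."""

    def _process_hidden_states(self, hidden_states, scheduler_output):
        return self.model.pooler(hidden_states, self.pooling_metadata)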

There were a few issues that @22quinn spent some time thinking about:

  • kv-cache management: for encoder models this is unnecessary because the attention mask is not causal, so optimizations such as chunked prefill and prefix caching don't apply. However, there are embedding models with pooling based on the last hidden state where these optimizations are applicable. One example is intfloat/e5-mistral-7b-instruct, which we use in the unit tests.
  • handling of m-rope, sliding window, multi-modal inputs, etc.: these things are mostly orthogonal to pooling or sampling and went into the abstract base class
  • input batch management: when chunked prefill is disabled, each pooling request stays in the batch for only one execute_model call. With chunked prefill, however, execute_model is called several times per request and the same logic that is used for the sampling models applies (see the sketch after this list).
  • cascade attention: this only applies to decoding.
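
To make the chunked-prefill point concrete, here is a minimal, hypothetical sketch (the driver loop and tensor shapes are illustrative assumptions, not the actual scheduler or runner code): the prompt is fed to the model in chunks, and only the call that consumes the final chunk produces a pooler output, while the intermediate calls return None for that request.

import torch

def run_pooling_request(prompt_token_ids: list[int], chunk_size: int):
    """Hypothetical driver loop: feed one prompt in chunks, pool only at the end."""
    num_computed = 0
    outputs = []
    while num_computed < len(prompt_token_ids):
        chunk = prompt_token_ids[num_computed:num_computed + chunk_size]
        num_computed += len(chunk)
        is_last_chunk = num_computed == len(prompt_token_ids)
        # Stand-in for the model forward pass over this chunk.
        hidden_states = torch.randn(len(chunk), 8)
        # Only after the final chunk is the whole prompt available,
        # so only then can the pooler produce an embedding; earlier
        # calls return None for this request.
        pooler_output = hidden_states[-1] if is_last_chunk else None
        outputs.append(pooler_output)
    return outputs

# Two intermediate calls yield None; the final call yields the embedding.
print(run_pooling_request(list(range(10)), chunk_size=4))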

cc: @mgoin, @WoosukKwon, @DarkLight1337

maxdebayser and others added 30 commits March 24, 2025 15:59
Encoder-only models can also benefit from the prefix caching that is
enabled by the kv cache

Signed-off-by: Max de Bayser <[email protected]>
This only passes mypy; it hasn't been tested yet

Signed-off-by: Max de Bayser <[email protected]>
... and disable cuda graphs for these models.

Signed-off-by: Max de Bayser <[email protected]>
Refactor the GPU model runner into a base model runner plus one model runner for sampling and another for pooling.

Signed-off-by: Max de Bayser <[email protected]>
@maxdebayser changed the title from "[PoC] Support embedding models in V1 with a dedicated model_runner" to "Support embedding models in V1 with a dedicated model_runner" on May 22, 2025
Signed-off-by: Max de Bayser <[email protected]>
@maxdebayser (Contributor Author)

There are some merge conflicts because the KV cache group PR was reverted, but it seems it will be re-added once the maintainer fixes the bugs that were found after it was merged into main. So I'm going to wait a bit before trying to resolve the conflicts.

@mergify mergify bot removed the needs-rebase label May 27, 2025
Signed-off-by: Max de Bayser <[email protected]>
@maxdebayser (Contributor Author)

The KV cache group PR was redone, so I've fixed the merge conflicts.

"""Tensors for pooling."""

prompt_lens: torch.Tensor
prompt_token_ids: Optional[torch.Tensor]
Contributor

Any benefit to using torch.Tensor instead of list[int]? Same question for prompt_lens.

Contributor Author

For prompt_lens it helps that the array is already in tensor format, so we can do things like torch.cumsum(prompt_lens, dim=0). For prompt_token_ids there is no strong reason, but it allows us to reuse the same _make_prompt_token_ids_tensor() function as in the non-pooling case.
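
For illustration, a standalone sketch of that cumsum usage (toy shapes, not code from this PR): with prompt_lens already on the device as a tensor, the per-prompt offsets in the flattened token batch fall out of a single torch.cumsum call, e.g. for last-token pooling.

import torch

# Flattened batch of three prompts with 4, 2 and 3 tokens each.
prompt_lens = torch.tensor([4, 2, 3])
hidden_states = torch.randn(int(prompt_lens.sum()), 8)  # (num_tokens, hidden_size)

# Cumulative lengths give the end offset of each prompt in the flat batch.
offsets = torch.cumsum(prompt_lens, dim=0)   # tensor([4, 6, 9])
last_token_indices = offsets - 1             # tensor([3, 5, 8])

# Last-token pooling: one hidden state per prompt.
pooled = hidden_states[last_token_indices]   # shape (3, 8)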


return PoolingMetadata(
    prompt_lens=torch.from_numpy(
        self.num_prompt_tokens[:self.num_reqs]).to(self.device),
Contributor

Same q here - why convert to tensor?

Contributor Author

See answer above.

last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos
        and last_token_id == request.eos_token_id):
if request.pooling_params and pooler_output is not None:
Contributor

Is it possible for pooler_output to be None here?

Contributor Author

Yes, during chunked prefill.

@facebook-github-bot

@22quinn has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


mergify bot commented May 28, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @maxdebayser.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 28, 2025
@mergify mergify bot removed the needs-rebase label May 29, 2025
Signed-off-by: Max de Bayser <[email protected]>

mergify bot commented May 30, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @maxdebayser.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 30, 2025