
[RFC][V1] LogitsProcessor interface #13360


Draft · njhill wants to merge 1 commit into main

Conversation

njhill (Member) commented Feb 16, 2025

Proposed abstraction for how to handle sampling parameters in relation to the persistent batch. This interface could then be used as an extension point for custom logits processors.

Key goals/ideas:

  • Logits processor implementations are configured globally; we won't support per-request logits processors.
  • They apply at the batch level rather than per-request, to allow for and encourage vectorized application.
  • Each logits processor encapsulates its own state and is responsible for updating it as needed based on notification of persistent batch updates and new output tokens each step. This minimizes the number of times tensors need to be reconstructed and updated on the GPU.

To demonstrate the idea I've implemented LPs for min_tokens, logit_bias and min_p, but if we decide to go this route it should be straightforward to refactor the others similarly.

Note this is just to discuss the general approach - it could still be simplified/refined further.

from __future__ import annotations

import dataclasses
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import torch

from vllm import SamplingParams


class LogitsProcessor(ABC):

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def update_states(
        self,
        batch_update: Optional[BatchUpdate] = None,
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
            batch_update is non-None iff there have been
            changes to the batch makeup.
        """
        raise NotImplementedError


@dataclasses.dataclass
class BatchUpdate:
    # Batch indices of any removed requests.
    removed: List[int]
    # (from, to) batch indices of any requests
    # moved within the batch.
    moved: List[Tuple[int, int]]
    # (index, params, output_tok_ids) for new
    # requests added to the batch.
    # TODO may need to include one or two other things here,
    # like prompt token ids.
    added: List[Tuple[int, SamplingParams, List[int]]]
    # The current number of requests in the batch.
    batch_size: int
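
For illustration, here is a rough sketch (not the actual implementation in this PR) of how a logit-bias-style processor might plug into this interface. It assumes SamplingParams carries a per-request logit_bias dict, as the InputBatch field shown later in the thread suggests: the processor keeps its own per-slot state, rebuilds its sparse index tensors only when a BatchUpdate arrives, and applies the biases to the whole batch in one vectorized operation.

from typing import Dict, Optional

import torch


class LogitBiasLogitsProcessor(LogitsProcessor):
    """Sketch: per-batch-slot {token_id: bias} maps, applied batch-wide."""

    def __init__(self, device: torch.device):
        self.device = device
        # batch index -> {token_id: bias}; missing means no bias for that slot.
        self.biases: Dict[int, Dict[int, float]] = {}
        self.row_idx: Optional[torch.Tensor] = None
        self.tok_idx: Optional[torch.Tensor] = None
        self.bias_vals: Optional[torch.Tensor] = None

    def update_states(self, batch_update: Optional[BatchUpdate] = None) -> None:
        if batch_update is None:
            return  # Batch makeup unchanged; nothing to rebuild.
        for index in batch_update.removed:
            self.biases.pop(index, None)
        for from_idx, to_idx in batch_update.moved:
            entry = self.biases.pop(from_idx, None)
            if entry is not None:
                self.biases[to_idx] = entry
            else:
                self.biases.pop(to_idx, None)
        for index, params, _output_tok_ids in batch_update.added:
            if params.logit_bias:
                self.biases[index] = params.logit_bias
        # Rebuild the sparse (row, token, bias) tensors once per batch change.
        rows, toks, vals = [], [], []
        for row, bias_map in self.biases.items():
            for tok, bias in bias_map.items():
                rows.append(row)
                toks.append(tok)
                vals.append(bias)
        if rows:
            self.row_idx = torch.tensor(rows, device=self.device)
            self.tok_idx = torch.tensor(toks, device=self.device)
            self.bias_vals = torch.tensor(vals, dtype=torch.float32,
                                          device=self.device)
        else:
            self.row_idx = self.tok_idx = self.bias_vals = None

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.row_idx is not None:
            # One vectorized indexed add over the whole batch.
            logits[self.row_idx, self.tok_idx] += self.bias_vals
        return logits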

@WoosukKwon @AlpinDale @houseroad


WoosukKwon (Collaborator) left a comment:
Looks very good to me!

Comment on lines +61 to +65
self.min_p_gpu: torch.Tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
Collaborator:

Can we preallocate for the others as well?

njhill (Member, Author):

@WoosukKwon this could potentially be done, but it's not quite as simple, because the maximum size could be much larger (e.g. max_batch_size * max_logit_bias_tokens * 3 for logit bias). These are sparse, so in practice the size would be much smaller, or more likely they wouldn't actually be used at all.

Intuitively, since we are minimizing how often these get updated, and still doing the transfer from CPU asynchronously from pinned memory, I would guess we aren't losing much by not preallocating. But I can benchmark and see. If we do it, we may want to just allocate to the high-water mark, so that if the feature isn't used, for example, there will be no allocation.

Collaborator:

Got it. Sounds good. As this is a niche feature, I think we shouldn't spend too much time on it.

mergify bot commented Feb 17, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @njhill.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

pin_memory=pin_memory)
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
# Pre-allocated device tensor
self.min_p_gpu: torch.Tensor = torch.empty((max_num_reqs, ),
Collaborator:

Can we call it min_p_device? We may have use cases other than GPU, right?

dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
Collaborator:

Why have min_p_cpu on top of min_p_cpu_tensor?

njhill (Member, Author):

This mirrors the existing implementation; I think the reason is that it's much cheaper to manipulate individual elements in numpy arrays than in tensors, so it's better to do those updates on a numpy view. But we still need the tensor to transfer to the GPU.
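
As a small illustration of that pattern (a sketch with made-up sizes, not the exact code from this PR): the pinned CPU tensor and its numpy view share memory, per-element updates go through the view, and a single non-blocking copy moves the populated prefix to the device before the forward pass.

import torch

max_num_reqs = 256  # illustrative value
device = torch.device("cuda")
pin_memory = torch.cuda.is_available()

# Pinned CPU tensor plus a numpy view for cheap per-element updates.
min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
                               dtype=torch.float32,
                               device="cpu",
                               pin_memory=pin_memory)
min_p_cpu = min_p_cpu_tensor.numpy()

# Pre-allocated device tensor that apply() reads from.
min_p_device = torch.empty((max_num_reqs, ), dtype=torch.float32, device=device)

# Per-request updates touch only the numpy view (no per-element tensor ops).
min_p_cpu[3] = 0.05
min_p_cpu[7] = 0.1

# One async host-to-device copy of the active prefix before the forward pass.
num_reqs = 8
min_p_device[:num_reqs].copy_(min_p_cpu_tensor[:num_reqs], non_blocking=True)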


logit_bias: List[Optional[Dict[int, float]]]
logits_procs: List[LogitsProcessor]
nongreedy_logits_procs: List[LogitsProcessor]
Collaborator:

Wondering why split into non-greedy and regular logits processors?

njhill (Member, Author):

Some logits processors don't affect greedy decoding (or if they do, only due to precision errors), so we want to only apply them when there are non-greedy (random-sampled) requests in the batch. I don't really like this name for the field but couldn't think of anything better :)

In theory, each of the logits processors keeps track of whether there are any applicable requests in the current batch and so this might not be needed. But it means they would each need to check the temperature too when requests are added. AFAIK we don't have validation to reject requests that have e.g. temp=0 and top_p=0.5 (top_p has no effect in this case).
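
A minimal sketch of how that split might be consumed by the sampler (hypothetical; the all_greedy flag here stands in for however the batch tracks whether any request uses random sampling):

import torch


def apply_logits_processors(logits: torch.Tensor,
                            logits_procs: list,
                            nongreedy_logits_procs: list,
                            all_greedy: bool) -> torch.Tensor:
    # Processors that affect greedy decoding (e.g. min_tokens, logit_bias)
    # always run.
    for proc in logits_procs:
        logits = proc.apply(logits)
    # Processors that only matter for random sampling (e.g. min_p) are
    # skipped when every request in the batch is greedy.
    if not all_greedy:
        for proc in nongreedy_logits_procs:
            logits = proc.apply(logits)
    return logits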

mmoskal (Contributor) commented Feb 19, 2025

@njhill shouldn't BatchUpdate also include the tokens that were sampled for each current sequence? It's not needed for the min_p/min_tokens/logit_bias ones you implemented, but it would be needed for anything more complicated?

njhill (Member, Author) commented Feb 19, 2025

> @njhill shouldn't BatchUpdate also include the tokens that were sampled for each current sequence? It's not needed for the min_p/min_tokens/logit_bias ones you implemented, but it would be needed for anything more complicated?

So this is not necessarily the final state, but currently it's handled via the list that's passed in with the added requests. This list is updated in-place with new tokens, so the impl can hang on to it if it needs to know them, and check it each time it's called. The idea is that BatchUpdate will only be present if reqs have been added or removed from the batch; otherwise this list of output ids can be checked for all of the current requests that the LP cares about. Of course this will need to be clearly documented if it remains the way that it's done :)

There will definitely need to be changes to this regardless, like adding the prompt tokens. We could possibly also include a pointer to an on-device tensor of just the newly generated tokens, which will be there anyhow, since it might be faster for the LP to use that directly (rather than multiple LPs each copying these tokens back to the GPU).

You can see an example of this for the min_tokens impl in this PR, which uses the length of that list.
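
A rough sketch of that pattern for a min_tokens-style processor (hypothetical code, not the PR's implementation; it assumes SamplingParams exposes min_tokens and all_stop_token_ids): the processor keeps a reference to the engine-owned output_tok_ids list received in added, so on later steps it only needs to check len() of that list.

from typing import Dict, List, Optional, Set, Tuple

import torch


class MinTokensLogitsProcessor(LogitsProcessor):
    """Sketch: masks stop/EOS tokens until a request has produced min_tokens."""

    def __init__(self):
        # batch index -> (output_tok_ids list, min_tokens, stop token ids)
        self.state: Dict[int, Tuple[List[int], int, Set[int]]] = {}

    def update_states(self, batch_update: Optional[BatchUpdate] = None) -> None:
        if batch_update is None:
            return
        for index in batch_update.removed:
            self.state.pop(index, None)
        for from_idx, to_idx in batch_update.moved:
            entry = self.state.pop(from_idx, None)
            if entry is not None:
                self.state[to_idx] = entry
            else:
                self.state.pop(to_idx, None)
        for index, params, output_tok_ids in batch_update.added:
            if params.min_tokens:
                # Keep a reference to the engine-owned list; it is extended
                # in place with each newly sampled token.
                self.state[index] = (output_tok_ids, params.min_tokens,
                                     set(params.all_stop_token_ids))

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Per-row for clarity; a real impl would vectorize this.
        for index, (out_ids, min_toks, stop_ids) in list(self.state.items()):
            if len(out_ids) < min_toks:
                logits[index, list(stop_ids)] = float("-inf")
            else:
                # Request has met its minimum; no further masking needed.
                del self.state[index]
        return logits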

mergify bot commented Mar 21, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @njhill.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 21, 2025
jiwangyihao commented Mar 25, 2025

In the V1 interface, I'm curious if the following code can still be used to customize the Logits processor for a single batch?

outputs = llm.generate(
    prompts,
    SamplingParams(
        temperature=temperature,
        max_tokens=1,
        allowed_token_ids=accepted_ids
    ),
)

And if not, is there any alternative solution?

Thank you for your time.

mmoskal (Contributor) commented Mar 25, 2025

You can set the guided decoding backend to "guidance" and use the following grammar, assuming your tokens are 74, 837, 1010, 1011, 1012, 1013:

start: ( <[74]> | <[837]> | <[1010-1013]> )*

for details see https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens

Make sure you have llguidance at least 0.7.10 (should be ready in 20 min); 0.7.9 as used in vLLM would reject this grammar in the front-end.

Of course you can also use strings in your grammar, not tokens, but then we will allow any tokenization of the given string, which may or may not be what you want.
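
A sketch of how that might be wired up in offline inference (hedged: this assumes the GuidedDecodingParams / guided_decoding parameters available in recent vLLM releases; the model name and prompt are purely illustrative, so check the docs for your version):

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Restrict generation to the token ids from the comment above.
grammar = "start: ( <[74]> | <[837]> | <[1010-1013]> )*"

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",  # illustrative model
          guided_decoding_backend="guidance")

outputs = llm.generate(
    ["Pick one of the allowed tokens:"],  # illustrative prompt
    SamplingParams(
        temperature=0.0,
        max_tokens=1,
        guided_decoding=GuidedDecodingParams(grammar=grammar),
    ),
)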

wlhgtc commented Mar 26, 2025

If I want to force stop thinking in an R1-like model (e.g., when prompt + outputs > 8192, force it to generate </think>), how can I get the prompt_length for a single request?
Thank you for your time.

return logits

# Convert logits to probability distribution
probability_values = torch.nn.functional.softmax(logits, dim=-1)
DreamGenX commented Mar 26, 2025

I believe you can do min_p without softmax, which might be faster:
Set logits[i] to -inf when logits[i] < logits_max + ln(min_p), where logits_max is the largest logit.

You can precompute ln(min_p) in update_states.
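
A sketch of that suggestion (hypothetical helper, not code from this PR): per-request ln(min_p) values would be precomputed in update_states (with -inf meaning "disabled"), so apply only needs a row-wise max and a comparison.

import torch


def apply_min_p_without_softmax(logits: torch.Tensor,
                                log_min_p: torch.Tensor) -> torch.Tensor:
    """logits: [num_reqs, vocab_size]; log_min_p: [num_reqs] with ln(min_p)
    per request, or -inf for requests where min_p is disabled."""
    # prob >= min_p * max_prob  <=>  logit >= max_logit + ln(min_p),
    # so the filter needs no softmax at all.
    max_logits = logits.max(dim=-1, keepdim=True).values
    threshold = max_logits + log_min_p.unsqueeze(-1)
    return logits.masked_fill(logits < threshold, float("-inf"))

# ln(min_p) would be refreshed once per batch update, e.g.
# log_min_p[i] = ln(min_p_i) if min_p_i > 0 else -inf.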

@mergify mergify bot added the tpu Related to Google TPUs label Mar 27, 2025
@mgoin mgoin self-requested a review April 14, 2025 21:36
@mergify mergify bot removed the needs-rebase label Apr 15, 2025
njhill (Member, Author) commented Apr 15, 2025

I've just rebased this PR, but haven't yet addressed some changes made since it was originally created:

  • The TPU impl has separate sampling metadata handling; we need to see how this will work with that.
  • There is a check for whether spec decode is supported based on the state of the input batch. Now that the input batch will change to contain LPs instead of the "raw" sampling param data, this check will need to be adjusted.
  • An out-of-vocab token check was added to the logit bias impl in the main branch. The re-implemented logit bias logic in this PR doesn't include that, but we should move that check to be part of request validation anyhow (as I've now commented on that PR).

afeldman-nm (Contributor):

JFYI I am picking up this work; see #16728.

NickLucche (Contributor) left a comment:

Nice work!
I assume no in-place update operations are going to be allowed, in order not to enforce any priority on the order in which the logits processors in logits_procs are applied, right?

> The TPU impl has separate sampling metadata handling; we need to see how this will work with that.

The main thing with TPU is that we need a way to limit the logits processors to the params that are currently supported, before tracing in a fixed order. With the current implementation there's no way to toggle some of them on/off, so we'll have to edit logits_procs from within TPUModelRunner.
Also, we're currently replacing the update_states function with one that does not introduce op dynamism for TPUs. Here update_states is scattered across different logits processors, so we would have to address each one individually.

As a rule of thumb for TPU, update_states functions are always OK when they operate on CPU tensors and then only move the result to the device.

cc @yaochengji

If padding_gap == 0 then:
increase 2X each time (exponential)
else:
first increase the size to twice,
first increase the size to twice,
Contributor:

Are these formatting changes due to...?

njhill (Member, Author):

Sorry, these must have been inadvertent changes while rebasing, probably by the IDE. We can revert them.

mergify bot commented Apr 18, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @njhill.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

# TODO possibly add more direct swap operation to LPs
swaps.append((i1, input_batch.max_num_reqs))
swaps.append((i2, i1))
swaps.append((input_batch.max_num_reqs, i2))
LucasWilkinson (Collaborator) commented Apr 28, 2025

We could maybe consider having input_batch track the swaps (i.e. calls to input_batch.swap_states) in an internal data structure that gets reset with calls to refresh_sampling_metadata. This way, if we do the "TODO possibly add more direct swap operation to LPs", the attn metadata implementations don't all need to be updated (FlashInfer does swapping too).
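
A rough sketch of that idea (hypothetical names, not code from this PR or from FlashInfer): the input batch records each swap as it happens and exposes the accumulated list, clearing it when the sampling metadata is refreshed, so downstream consumers can replay the moves without each call site re-deriving them.

from typing import List, Tuple


class SwapTrackingBatch:
    """Sketch: accumulate (i, j) swaps applied to the persistent batch."""

    def __init__(self) -> None:
        self._pending_swaps: List[Tuple[int, int]] = []

    def swap_states(self, i1: int, i2: int) -> None:
        # ... swap the actual per-request state between slots i1 and i2 ...
        self._pending_swaps.append((i1, i2))

    def consume_swaps(self) -> List[Tuple[int, int]]:
        # Called when refreshing sampling metadata: hand the accumulated
        # swaps to logits processors / attn metadata builders and reset.
        swaps, self._pending_swaps = self._pending_swaps, []
        return swaps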

Labels: needs-rebase, RFC, tpu (Related to Google TPUs), v1