-
-
Notifications
You must be signed in to change notification settings - Fork 7.6k
[WIP]: DRY sampling #16695
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[WIP]: DRY sampling #16695
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution!
So the main difference is that in V1 sampling params are managed "continuously" as in they're allocated once and then overwritten as requests come through to avoid re-allocation costs.
Because of this extra_data
field it is unclear how to handle that optimally here. Do you have more context to share on discussions about DRY as to why this could/would not be integrated like penalties are?
if not hasattr(sampling_metadata, 'extra_data') or sampling_metadata.extra_data is None: | ||
# If no extra_data field exists or is None, cannot apply DRY | ||
return logits | ||
|
||
# Check if any request might have DRY enabled (basic check) | ||
# More robust check would involve iterating through extra_data first | ||
has_potential_dry = any( | ||
data and data.get('dry_multiplier', _DRY_DEFAULT_MULTIPLIER) > 0 | ||
for data in sampling_metadata.extra_data | ||
) | ||
if not has_potential_dry: | ||
return logits |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this check should ideally be moved to a gpu_input_batch property like penalties.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a longer discussion within the #bounty-program long thread in Slack - but I think Nick was saying he wanted it to be part of extra args. I'm not quite familiar with the codebase itself, but DRY essentially works like other penalties and has a few params. It'd be great if it's a first class sampler where things like dry_multiplier, etc are just passed into the engine like repetition_penalty is. I'm fairly confused about the passing of extra args myself since some part of it reads to me like it's not integrated fully yet? But the Slack thread and latest messages were what I was going by.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@0xymoro apologies for the confusion. As @NickLucche says, the state that only depends on the current set of requests and their parameters, should go in InputBatch
here (you can see the state associated with other sampling params there).
Then, add logic in the add_request
method to update this state based on your own request parameters that you can retrieve from request.sampling_params.extra_args
.
Logic should also be added to the remove_request()
, swap_states()
and condense()
methods to remove/reorder the requests within the preallocated state, and to _make_sampling_metadata()
to update the SamplingMetadata based on the current (changed) state in the input batch.
All of this is what #13360 aims to abstract/encapsulate into the LogitProcessor interface.
I think current implementation is quite inefficient, there's too many slow fors and data movement imo wrt the other sampling options that are supported. Would you be interested in editing the |
Thanks @0xymoro. Is this equivalent to the "new/optimized implementation" that you had mentioned in the slack thread? Like @NickLucche says, I think there should in any case be opportunities to vectorize at least some of the operations which should make a big difference when there's a number of requests using DRY at the same time. That's mostly orthogonal to the other changes mentioned above though, I think they could be tackled in either order. |
From Slack discussion - will wait until #13360 is more finalized before building on top of that, so will come back to this PR or create a new one rebased on top of when that is merged in |
New feature (work in progress, would like some more guidance): #8581
Implementation of logits processor for DRY sampling. The code is a port of an optimized version of DRY that I've been running for months now across trillions of tokens and is stable.
Some of the code was adapted from the original optimized DRY by Gemini 2.5 Pro. The logic should be the same but once we're able to hook it up and test it will see then.
Would love some help on routing the extra_args into here so the sampler can be exposed. Got lost in the code a bit there and it's not clear to me how it works currently with old & new V1.