[V1][TPU] Support V1 Sampler for ragged attention #14227
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; instead, only a reduced set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: add the 🚀 …
Force-pushed from 55fbade to 75440be.
Also, sadly, compilation takes forever now. Waiting for new kernels to address that.
vllm/v1/worker/tpu_model_runner.py
Outdated
num_reqs_to_sample = MIN_NUM_SEQS
while True:
    self._dummy_run(self.kv_caches, num_tokens)
    logger.info(" -- num_tokens: %d", num_tokens)
    xm.mark_step()
    xm.wait_device_ops()
    if num_tokens >= self.max_num_tokens:
        # Compile for different sampler outputs in
        # [MIN_NUM_SEQS, pad(max_num_reqs)]
        while True:
            self._dummy_run(self.kv_caches, num_tokens)
            logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
                        num_reqs_to_sample)
            xm.mark_step()
            xm.wait_device_ops()
            if num_reqs_to_sample >= self.max_num_reqs:
                break
            num_reqs_to_sample *= 2
    if num_tokens >= self.scheduler_config.max_num_batched_tokens:
        break
    num_tokens *= 2
This doesn't seem right... each time we are only passing in num_tokens, so what is the point of num_reqs_to_sample in this loop? Looking around, it seems like num_tokens_to_sample is not being passed in.
This potential explosion of compilation makes me think a better solution would be to completely separate the graphs for model forward and sampling. This way we can just define a "max sampling batch size" that is handled separately, so we end up with M+N compilations rather than M*N.
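(For illustration only, not the PR's code: a minimal sketch of the separated warm-up this would enable, where compile_forward and compile_sampler are hypothetical stand-ins for the padded dummy forward and dummy sampling runs.)

from typing import Callable

def warm_up_separately(compile_forward: Callable[[int], None],
                       compile_sampler: Callable[[int], None],
                       max_num_tokens: int,
                       max_num_reqs: int,
                       min_num_seqs: int = 8) -> None:
    # Sweep padded token counts for the forward graph: M compilations.
    num_tokens = 16
    while True:
        compile_forward(num_tokens)
        if num_tokens >= max_num_tokens:
            break
        num_tokens *= 2
    # Sweep padded request counts for the sampling graph: N compilations,
    # instead of re-compiling the sampler for every (num_tokens, num_reqs) pair.
    num_reqs = min_num_seqs
    while True:
        compile_sampler(num_reqs)
        if num_reqs >= max_num_reqs:
            break
        num_reqs *= 2

With the two sweeps decoupled, adding a new token bucket no longer multiplies the number of sampler graphs.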
I don't think I can get to M+N here, as sampling depends on both num_reqs_to_sample and hidden_state, so for each different hidden_state shape we still need to compile. But we will only have M*N compilations of the sampling step, which should be much lighter.
I did consider breaking the graph multiple times but iirc the main concern was the interaction with torch.compile.
Quickly looking into an issue with greedy sampling and the control flow introduced here: #13587. EDIT: pushed a new commit addressing it.
@NickLucche Nice work with the sampler and the optimizations to improve compilation times. Left some comments and questions.
vllm/v1/sample/sampler.py
Outdated
# Apply mask using boolean indexing
logits[~valid_token_mask] = -float('inf')
# Apply mask using boolean indexing (xla friendly)
logits.masked_fill_(~valid_token_mask, -float("inf"))
how did you discover this optimization? :)
I think I was inspecting the compilations with PT_XLA_DEBUG=2
Should we revert this, or does it also help GPU performance?
Good point, I can add this in a separate PR if need be, now that the two samplers are separated.
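(Side note for context, not from the PR: a small standalone comparison of the two masking styles. As I understand it, boolean-index assignment selects a data-dependent number of elements, which XLA's lazy tracing handles poorly, while masked_fill_ stays a fixed-shape elementwise op.)

import torch

logits = torch.randn(4, 8)
valid_token_mask = torch.zeros(4, 8, dtype=torch.bool)
valid_token_mask[:, :5] = True  # pretend only the first 5 token ids are allowed

a = logits.clone()
a[~valid_token_mask] = -float("inf")              # boolean-index assignment

b = logits.clone()
b.masked_fill_(~valid_token_mask, -float("inf"))  # in-place, static-shape variant

assert torch.equal(a, b)  # same result on CPU/GPU; the traced XLA graphs differ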
tests/v1/tpu/test_sampler.py
Outdated
s = time()
_ = llm.generate(prompts, sampling_params)
run2 = time() - s
assert run1 > run2
How much slower is run1? Maybe it would be safer to do something like run1 > run2 * 2 (otherwise you may be in the noise range if it does recompile).
performance is back with the kernel update in #14310, so I'll bring back the 0.1 check :)
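(A sketch of the margin-based check being discussed, reusing llm, prompts and sampling_params from the test above; the 0.1 factor is my reading of the "0.1 check", i.e. the warm run should take less than 10% of the cold run.)

from time import time

def timed(fn):
    start = time()
    fn()
    return time() - start

run1 = timed(lambda: llm.generate(prompts, sampling_params))  # pays XLA compilation
run2 = timed(lambda: llm.generate(prompts, sampling_params))  # warm, cached graphs
# Wide margin so timing noise (or an unexpected recompile) can't flip the result.
assert run2 < run1 * 0.1, f"run1={run1:.2f}s, run2={run2:.2f}s"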
    kv_caches,
    num_tokens: int,
) -> None:
@torch.no_grad()
why did you need torch.no_grad? I recall this was causing errors
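(For reference, torch.no_grad() used as a decorator, as in the hunk above, just disables autograd tracking for every call to the wrapped function; a minimal standalone example:)

import torch

@torch.no_grad()  # no autograd graph is recorded inside the call
def dummy_forward(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    return model(x)

out = dummy_forward(torch.nn.Linear(4, 4), torch.randn(2, 4))
assert not out.requires_grad  # saves memory/compute during warm-up runs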
This pull request has merge conflicts that must be resolved before it can be merged.
Thanks a lot for the review and comments, @alexm-redhat!
Force-pushed from c5a52a2 to a56afc2.
Awesome work building the structure for future sampling param enablement once we get performant kernels!
Eval smoke test looks good (and runs fast!)
Running generate_until requests: 100%|██████████████████████| 1319/1319 [01:44<00:00, 12.67it/s]
INFO:lm-eval:Output path not provided, skipping saving results aggregated
vllm (pretrained=Qwen/Qwen2-1.5B-Instruct,max_model_len=2048,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.5868|± |0.0136|
| | |strict-match | 5|exact_match|↑ |0.5815|± |0.0136|
Full log here:
This pull request has merge conflicts that must be resolved before it can be merged.
Hi @NickLucche, thanks for your contribution. I found two places causing recompilation. Is it possible to resolve them?
Sorry for chiming in this late, but I think we need a broader discussion on how to share the sampler between different hardware backends. In V0, we found that the sampler implementation got over-complicated partly because different hardware backends added some if statements here and there just for their own purposes. I do really want to prevent it from happening this time.
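(One way to keep backend differences out of a shared sampler is to override only the backend-sensitive pieces in a subclass; sketched below with stand-in classes, not vLLM's actual class hierarchy.)

import torch

class BaseSampler:  # stand-in for a shared, backend-agnostic sampler
    def apply_token_mask(self, logits: torch.Tensor,
                         valid: torch.Tensor) -> torch.Tensor:
        logits[~valid] = -float("inf")  # fine on CPU/CUDA
        return logits

class TPUSampler(BaseSampler):
    def apply_token_mask(self, logits: torch.Tensor,
                         valid: torch.Tensor) -> torch.Tensor:
        # Override just the XLA-sensitive step instead of adding
        # `if tpu:` branches to the shared implementation.
        return logits.masked_fill_(~valid, -float("inf"))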
Head branch was pushed to by a user without write access
Force-pushed from fc06dd0 to 50ef555.
Rebased and updated.
@WoosukKwon I've separated the Sampler code for TPU into a different namespace so it doesn't make use of the CUDA path any longer. Let me know what you think.
@yaochengji I've addressed the small recompilation issue that was happening due to a mismatch in pre-processing at compile time vs runtime.
PS: I'd also like to postpone the torch.compile wrapping of the sampling graph, as it's currently unclear whether we get any benefit from it (compile time appears to go up).
@NickLucche Thanks for your update! Previously, when resolving the recompilation issue in the logits processor, I noticed the compilation time is about 500 ms. This is not acceptable in some use cases (e.g. a typical TPOT requirement is 100 ms). Therefore, could you resolve all the recompilation issues? BTW, there's an ongoing task to ensure that there's no recompilation after warmup: #14580.
This is a different recompilation, happening in pre- rather than post-processing, and it is dependent on existing V1 code.
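(As an aside, one way to verify "no recompilation after warmup" is to snapshot torch_xla's CompileTime metric around a request; a rough sketch, assuming the standard torch_xla.debug.metrics API.)

import torch_xla.debug.metrics as met

def xla_compile_count() -> int:
    # metric_data returns (num_samples, accumulator, raw_samples), or None
    # if the metric has not been recorded yet.
    data = met.metric_data("CompileTime")
    return data[0] if data else 0

def assert_no_recompilation(run_request) -> None:
    before = xla_compile_count()
    run_request()  # a request whose shapes were already covered during warm-up
    assert xla_compile_count() == before, "graph recompiled after warm-up"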
LGTM!
This pull request has merge conflicts that must be resolved before it can be merged.
Thank you @NickLucche for this PR! Do you plan to merge it soon? I hope to send a follow-up PR to optimize top-k for TPU.
@hyeygit Thanks for your optimization work. Feel free to ping me when you have something ready.
Updated version of #13982 with ragged kernel attention.
All considerations described in the previous PR attempt are still valid (in particular, still aiming to get a single optimized graph with model+sampling); the main difference is in how the number of sampled tokens should be handled now that it's not always equal to the batch size.
I think there's fundamentally one trade-off to consider:
- … max_num_seqs.
- … model.sample and eat the cost of compilation (as it happens on main now).
I think there's also a third option, which I tried to implement here: do_sample_indices: [4, 10] => padded_do_sample_indices: [4, 10, 0, 0].
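(My reading of that third option, as a rough standalone sketch; tensor shapes and the choice of 0 as the padding index are illustrative assumptions.)

import torch

def pad_do_sample_indices(do_sample_indices: torch.Tensor,
                          padded_num_reqs: int) -> torch.Tensor:
    # Pad with index 0 so the gather below keeps a static shape; the padded
    # rows get sampled too, and their outputs are simply discarded.
    num_pad = padded_num_reqs - do_sample_indices.numel()
    pad = torch.zeros(num_pad, dtype=do_sample_indices.dtype)
    return torch.cat([do_sample_indices, pad])

hidden_states = torch.randn(16, 8)                     # [num_tokens, hidden_size]
do_sample_indices = torch.tensor([4, 10])              # last-token position per request
padded = pad_do_sample_indices(do_sample_indices, 4)   # tensor([4, 10, 0, 0])
sample_hidden = hidden_states[padded]                  # always [4, hidden_size]
real_rows = sample_hidden[:do_sample_indices.numel()]  # drop the padded requests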
Note: the compilation test is currently failing (re-running a request with different sampling params isn't as fast); will dig more into it. Execution just isn't as fast as it was in the previous PR: compiling doesn't always drop times by >10x. This is happening on main too, so I don't think it has anything to do with this PR. As a result, I lowered the bar in test_sampler_compilation.py. We can focus on optimizing performance later (this is anyway the first PR where we try to track it).
Tested with VLLM_USE_V1=1 python -m pytest -s tests/v1/tpu/test_sampler.py.
Update: I broke up the model + sampling graph again, so now the compilation stages are separated: