[V1][TPU] Speed up top-k on TPU by using torch.topk #15242
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Force-pushed from 803f9cf to 6a32ed7.
cc @NickLucche if you'd like to take a look.
Thanks hyeygit :) Can you please add a test to guard this change from future regression?
Also:
* Added an env var VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION to guard the optimization logic. This is a conservative measure, so that we can turn off these changes in case they cause trouble.
* Edited tpu/test_sampler.py to include top-k.

Signed-off-by: Hyesoo Yang <[email protected]>
@@ -95,6 +95,7 @@
 VLLM_DP_MASTER_PORT: int = 0
 VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
 VLLM_V0_USE_OUTLINES_CACHE: bool = False
+VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
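For context, vLLM's envs.py registers settings as a dictionary of lambdas; below is a minimal sketch of how the new flag might be registered and consumed. The exact entry and call site in this PR are assumptions on my part, not a copy of the actual change.

```python
import os

# Hypothetical registration, mirroring the lambda-based pattern of the
# neighboring variables; the flag defaults to "optimization enabled".
environment_variables = {
    "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
    lambda: bool(int(os.getenv("VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION", "0"))),
}

# A caller could then gate the fast path on TPU, e.g.:
if not environment_variables["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]():
    print("torch.topk fast path enabled")
else:
    print("falling back to the original sort-based implementation")
```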
I don't think we need this. Can you add a follow up PR to remove this?
Yes, will send a follow up PR to remove this once things look stable after merging.
Can you explain why there is a need to have an option to disable this?
Thanks for the review! This is a conservative measure -- in case there is any adverse effect from this change we can quickly disable it. This is meant to be temporary; I will send a follow-up PR to remove it once things look stable.
Thanks for the PR @hyeygit ! I did see that topk impl but I was (wrongly) afraid it would introduce dynamic shapes as the output has shape K, iirc. As a general rule, I'd be in favor of testing PRs a bit more before merging rather than introducing temporary env variables to disable features.
Hi @hyeygit, thanks for your work here. I was just studying your PR and I had a question. If the implementation says:
Does this not mean that p would not have an effect on the sampler? Or does it have an effect somehow that I am missing?
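For reference (not code from this PR), a standard top-p (nucleus) filtering step looks roughly like the sketch below; the question above is essentially whether this masking ends up being skipped.

```python
import torch

def apply_top_p(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus (top-p) filtering: keep the smallest set of tokens whose
    cumulative probability exceeds p, mask everything else."""
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cum_probs = sorted_probs.cumsum(dim=-1)
    # A token is dropped when the cumulative mass *before* it already exceeds p.
    drop = (cum_probs - sorted_probs) > p
    sorted_logits = logits.gather(-1, sorted_idx).masked_fill(drop, float("-inf"))
    # Scatter the masked logits back into the original vocabulary order.
    return torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
```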
@hyeygit I'm curious whether this was tested at all? It looks completely wrong to me - torch.topk only takes a scalar and k here is a batch-size vector, which also may have entries equal to the vocab size (i.e. top k disabled). Also as @HAKSOAT points out in the current state top-p is ignored. You can still use torch.topk with the max top k from the batch and gather the values from the result, and this is much faster than doing the whole sort - for GPU too. I have opened a PR for this #15478, perhaps the TPU-specific path is no longer needed after that (until there's a dedicated kernel)?
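A minimal sketch of the max-k-plus-gather idea described above; the helper name is mine and the actual code in #15478 may differ.

```python
import torch

def apply_top_k_batched(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """logits: [batch, vocab]; k: [batch] int64, where k == vocab_size means
    top-k is disabled for that request."""
    max_k = int(k.max())
    # One torch.topk call with the largest k in the batch; values come back
    # sorted in descending order along the last dimension.
    topk_vals = torch.topk(logits, max_k, dim=-1).values      # [batch, max_k]
    # Gather each row's k_i-th largest value as its cutoff threshold.
    cutoff = topk_vals.gather(-1, (k - 1).unsqueeze(-1))       # [batch, 1]
    # Rows with k == vocab_size end up with the row minimum as the cutoff,
    # so nothing gets masked for them.
    return logits.masked_fill(logits < cutoff, float("-inf"))
```

Note that if any request disables top-k, max_k equals the vocab size and the topk call degenerates to a full sort, so a real implementation would likely special-case that.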
TL;DR
Using torch.topk leads to a significant speed-up for top-k on TPU.

Background / Baseline
#14227 added sampling support to TPU. Benchmark results (from #13982) showed that TPUv6 is about 38x slower than L4 GPU (507ms on TPU vs 13ms on L4 GPU for "Running 32 elapsed time").
Improvement for TPU
Existing vLLM v1 code uses discrete torch operations to implement top-k and top-p (e.g. sort and mask); see apply_top_k_top_p. This turns out to be inefficient on TPU. For top-k, PyTorch already provides an existing API, torch.topk, with a corresponding XLA lowering, and using it makes things much faster.
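As a rough illustration of the difference (scalar k for simplicity; the real apply_top_k_top_p also handles top-p and per-request k values, so this is only a sketch):

```python
import torch

def top_k_mask_via_sort(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Sort-and-mask style: sorts the entire vocab dimension just to find the
    # k-th largest value per row, which is the slow part on TPU.
    cutoff = logits.sort(dim=-1, descending=True).values[:, k - 1, None]
    return logits.masked_fill(logits < cutoff, float("-inf"))

def top_k_mask_via_topk(logits: torch.Tensor, k: int) -> torch.Tensor:
    # torch.topk only materializes the k largest values and has an XLA
    # lowering, so it avoids the full-vocab sort.
    cutoff = torch.topk(logits, k, dim=-1).values[:, -1, None]
    return logits.masked_fill(logits < cutoff, float("-inf"))
```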
New top-k TPU benchmark (w/ torch.topk)
With the changes in this PR, the benchmark results suggest a ~250x speed-up for top-k on TPU:
Notice the ~250x reduction in “Running 32 elapsed time” (1.9 ms vs 507 ms previously per #13982).
GPU benchmark with torch.topk
For completeness, I ran the same benchmark on an H100 GPU and observed similar runtime performance as on v6e-1:
Notice the “Running 32 elapsed time” is 1.2 ms, in the same order of magnitude as TPU’s 1.9 ms.
Conclusion
For top-k on TPU, this PR uses torch.topk instead of vLLM's discrete implementation, yielding a significant speed-up.