
[V1][TPU] Speed up top-k on TPU by using torch.topk #15242


Merged: 1 commit merged into vllm-project:main from hyeygit:topk_tpu on Mar 21, 2025

Conversation

@hyeygit (Contributor) commented on Mar 20, 2025

TL;DR

Using torch.topk leads to a significant speedup for top-k sampling on TPU.

Background / Baseline

#14227 added sampling support for TPU. Benchmark results (from #13982) showed that TPU v6 is about 38x slower than an L4 GPU (507 ms on TPU vs 13 ms on the L4 GPU for "Running 32 elapsed time").

Improvement for TPU

The existing vLLM v1 code implements top-k and top-p with discrete torch operations (e.g. sort and mask); see apply_top_k_top_p. This turns out to be inefficient on TPU. For top-k, PyTorch already provides a dedicated API, torch.topk, with a corresponding XLA lowering, and using it makes the operation much faster.
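
Roughly, the difference looks like the sketch below. This is my own simplification for illustration (a single scalar k, hypothetical function names), not the actual apply_top_k_top_p code or the exact change in this PR:

    import torch

    def topk_mask_via_sort(logits: torch.Tensor, k: int) -> torch.Tensor:
        # Baseline style: sort the whole vocab dimension, then mask everything
        # strictly below the k-th largest value.
        sorted_logits, _ = logits.sort(dim=-1, descending=False)
        threshold = sorted_logits[:, -k].unsqueeze(-1)   # k-th largest per row
        return logits.masked_fill(logits < threshold, float("-inf"))

    def topk_mask_via_topk(logits: torch.Tensor, k: int) -> torch.Tensor:
        # This PR's direction: let torch.topk (which has an XLA lowering) find
        # the k-th largest value directly, without sorting the full vocab.
        kth_value = torch.topk(logits, k, dim=-1).values[:, -1:]  # [batch, 1]
        return logits.masked_fill(logits < kth_value, float("-inf"))

Both functions produce the same mask; the second avoids materializing a full sort of the vocabulary, which appears to be the expensive part on TPU.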

New top-k TPU benchmark (w/ torch.topk)

With the changes in this PR, the benchmark results show a ~250x speedup for top-k on TPU:

# Top-k sampling benchmark on TPU v6e-1. Note that top-p is disabled (set to None).

INFO 03-17 21:23:58 [__init__.py:256] Automatically detected platform tpu.
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Compiling/Warmup 1 elapsed time: 1.2743589878082275
Compiling/Warmup 4 elapsed time: 0.9751193523406982
Compiling/Warmup 16 elapsed time: 1.5950334072113037
Compiling/Warmup 32 elapsed time: 1.5763509273529053
Running 1 elapsed time: 0.002122640609741211
Running 1 elapsed time: 0.001752614974975586
Running 1 elapsed time: 0.0017757415771484375
Running 1 elapsed time: 0.0017747879028320312
Running 4 elapsed time: 0.0017614364624023438
Running 4 elapsed time: 0.0017540454864501953
Running 4 elapsed time: 0.0017423629760742188
Running 4 elapsed time: 0.0017385482788085938
Running 16 elapsed time: 0.0018427371978759766
Running 16 elapsed time: 0.0018305778503417969
Running 16 elapsed time: 0.0018429756164550781
Running 16 elapsed time: 0.0018262863159179688
Running 32 elapsed time: 0.0019214153289794922
Running 32 elapsed time: 0.0019321441650390625
Running 32 elapsed time: 0.0019261837005615234
Running 32 elapsed time: 0.001947641372680664

Notice the ~250x reduction in “Running 32 elapsed time” (1.9 ms vs 507 ms previously per #13982).
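
(For context, a timing loop of roughly the shape below, shown purely as an illustration with made-up names and sizes, would print output in the format above. The actual benchmark is the script from #13982 and exercises the full sampler path, not a bare torch.topk call as here.)

    import time
    import torch
    import torch_xla.core.xla_model as xm  # assumes torch_xla with a TPU attached

    def time_topk(vocab_size: int = 32000, k: int = 10) -> None:
        device = xm.xla_device()
        for phase, reps in (("Compiling/Warmup", 1), ("Running", 4)):
            for batch in (1, 4, 16, 32):
                logits = torch.randn(batch, vocab_size, device=device)
                for _ in range(reps):
                    start = time.time()
                    torch.topk(logits, k, dim=-1)
                    xm.mark_step()          # flush the lazily built XLA graph
                    xm.wait_device_ops()    # block until the TPU finishes
                    print(f"{phase} {batch} elapsed time: {time.time() - start}")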

GPU benchmark with torch.topk

For completeness, I ran the same benchmark on an H100 GPU and observed runtime performance similar to v6e-1:

# Top-k sampling benchmark on H100 GPU VM. Note that top-p is disabled (set to None).

INFO 03-18 20:45:15 [__init__.py:256] Automatically detected platform cuda.
WARNING 03-18 20:45:16 [topk_topp_sampler.py:63] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
Compiling/Warmup 1 elapsed time: 0.8188707828521729
Compiling/Warmup 4 elapsed time: 0.08226132392883301
Compiling/Warmup 16 elapsed time: 0.001528024673461914
Compiling/Warmup 32 elapsed time: 0.0015192031860351562
Running 1 elapsed time: 0.0011334419250488281
Running 1 elapsed time: 0.0010404586791992188
Running 1 elapsed time: 0.0010061264038085938
Running 1 elapsed time: 0.0009946823120117188
Running 4 elapsed time: 0.0011370182037353516
Running 4 elapsed time: 0.0011453628540039062
Running 4 elapsed time: 0.001129150390625
Running 4 elapsed time: 0.0011076927185058594
Running 16 elapsed time: 0.0011191368103027344
Running 16 elapsed time: 0.001134634017944336
Running 16 elapsed time: 0.001131296157836914
Running 16 elapsed time: 0.0011317729949951172
Running 32 elapsed time: 0.001157999038696289
Running 32 elapsed time: 0.0011749267578125
Running 32 elapsed time: 0.0011844635009765625
Running 32 elapsed time: 0.0011768341064453125

Notice that the “Running 32 elapsed time” is 1.2 ms, the same order of magnitude as the TPU’s 1.9 ms.

Conclusion

  • Using torch.topk instead of vLLM's discrete sort-and-mask implementation yields up to a 250x speed-up for top-k on TPU.
  • TPU (v6e-1) and GPU (H100) are at performance parity when using torch.topk.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify bot added the v1 label on Mar 20, 2025
@hyeygit force-pushed the topk_tpu branch 2 times, most recently from 803f9cf to 6a32ed7 on March 20, 2025 at 18:55
@hyeygit marked this pull request as ready for review on March 20, 2025 at 19:14
@hyeygit (Contributor, Author) commented on Mar 20, 2025

cc @NickLucche if you'd like to take a look.

@yarongmu-google (Contributor) left a comment:

Thanks hyeygit :) Can you please add a test to guard this change against future regressions?

The updated commit's message reads (in part):

Also:
* Added an env VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION to guard the
  optimization logic. This is to be conservative, so that we can turn
  off these changes in case they cause troubles.
* Edited tpu/test_sampler.py to include top-k.

Signed-off-by: Hyesoo Yang <[email protected]>

In the diff, @hyeygit added a new environment variable, which drew the review comments below:

    @@ -95,6 +95,7 @@
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
    +VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
A collaborator commented:

I don't think we need this. Can you add a follow up PR to remove this?

@hyeygit (Contributor, Author) replied:

Yes, will send a follow up PR to remove this once things look stable after merging.

@robertgshaw2-redhat (Collaborator) left a comment:

Can you explain why there is a need to have an option to disable this?

@robertgshaw2-redhat enabled auto-merge (squash) on March 21, 2025 at 00:55
@github-actions bot added the ready label on Mar 21, 2025
@hyeygit (Contributor, Author) commented on Mar 21, 2025

Can you explain why there is a need to have an option to disable this?

Thanks for the review! This is a conservative measure -- in case there is any adverse effect from this change we can quickly disable it. This is meant to be temporary. Will send follow-up PR to remove once things look stable.

@vllm-bot merged commit 4719505 into vllm-project:main on Mar 21, 2025
36 of 40 checks passed
@NickLucche (Contributor) commented:

Thanks for the PR @hyeygit! I did see that topk impl, but I was (wrongly) afraid it would introduce dynamic shapes, since the output has shape K, IIRC.

As a general rule, I'd be in favor of testing PRs a bit more before merging rather than introducing temporary env variables to disable features.

@HAKSOAT commented on Mar 22, 2025

Hi @hyeygit, thanks for your work here.

I was just studying your PR and I had a question.

If the implementation says:

        if k is not None and p is None:
            # the implementation
        else:
            pass

Does this not mean that p would not have an effect on the sampler?

Or does it have an effect somehow that I am missing?

@njhill (Member) commented on Mar 25, 2025

@hyeygit I'm curious whether this was tested at all? It looks completely wrong to me - torch.topk only takes a scalar and k here is a batch-size vector, which also may have entries equal to the vocab size (i.e. top k disabled). Also as @HAKSOAT points out in the current state top-p is ignored.

You can still use torch.topk with the max top k from the batch and gather the values from the result, and this is much faster than doing the whole sort - for GPU too. I have opened a PR for this #15478, perhaps the TPU-specific path is no longer needed after that (until there's a dedicated kernel)?
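
(For readers following the thread, a minimal sketch of the batched approach njhill describes might look like the following; the function name and shapes are mine, and the real implementation lives in #15478.)

    import torch

    def topk_mask_batched(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        # One torch.topk call with the largest k in the batch, then gather
        # each row's own k-th largest value as its masking threshold.
        max_k = int(k.max())
        topk_vals = torch.topk(logits, max_k, dim=-1).values  # [batch, max_k], descending
        thresholds = topk_vals.gather(1, (k - 1).unsqueeze(1).to(torch.long))  # [batch, 1]
        # Rows with k == vocab_size keep everything, since nothing falls below
        # their minimum value.
        return logits.masked_fill(logits < thresholds, float("-inf"))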

lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
nishith-fujitsu pushed a commit to nishith-fujitsu/vllm that referenced this pull request Apr 9, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Labels: ready (ONLY add when PR is ready to merge/full CI is needed), v1
7 participants