Skip to content

[TPU][V1] Enable Top-P #16843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 22, 2025
Merged

[TPU][V1] Enable Top-P #16843

merged 4 commits into from
Apr 22, 2025

Conversation

NickLucche
Copy link
Contributor

@NickLucche NickLucche commented Apr 18, 2025

Follow up to #15489.
This PR enables the work done #15736.
We split up the contribs to better monitor performance as described more in detail here #16268.

Here's the benchmark I ran on single chip with and without top-p:

Benchmark commands
MODEL=meta-llama/Llama-3.1-8B-Instruct
# Server
VLLM_XLA_CHECK_RECOMPILATION=1 VLLM_USE_V1=1 vllm serve $MODEL \
 --disable-log-requests \
 --port 8004 \
 --gpu-memory-utilization 0.95 \
 --max-num-seqs 512 \
 --max-num-batched-tokens 512 \
 --tensor-parallel-size 1 \
 --max-model-len 2048 > "$VLLM_LOG" 2>&1 &
# Client with the newly added --top-k --top-p options
python benchmarks/benchmark_serving.py \
    --backend vllm \
    --model $MODEL \
    --dataset-name random \
    --random-input-len 1720 \
    --random-output-len 128 \
    --random-prefix-len 0 \
    --ignore-eos \
    --top-k 8 --top-p 0.5 \
    --port 8004 > "$BM_LOG"
# with top-p (this PR) 

 ============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  128.28    
Total input tokens:                      1720000   
Total generated tokens:                  128000    
Request throughput (req/s):              7.80      
Output token throughput (tok/s):         997.85    
Total Token throughput (tok/s):          14406.52  
---------------Time to First Token----------------
Mean TTFT (ms):                          63332.60  
Median TTFT (ms):                        62792.18  
P99 TTFT (ms):                           124350.61 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          32.65     
Median TPOT (ms):                        33.30     
P99 TPOT (ms):                           63.55     
---------------Inter-token Latency----------------
Mean ITL (ms):                           37.65     
Median ITL (ms):                         33.21     
P99 ITL (ms):                            34.56     
==================================================
# without top-p (26507f897) 

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  129.28    
Total input tokens:                      1720000   
Total generated tokens:                  128000    
Request throughput (req/s):              7.74      
Output token throughput (tok/s):         990.13    
Total Token throughput (tok/s):          14295.02  
---------------Time to First Token----------------
Mean TTFT (ms):                          64067.78  
Median TTFT (ms):                        63521.31  
P99 TTFT (ms):                           125308.93 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          32.56     
Median TPOT (ms):                        33.42     
P99 TPOT (ms):                           63.77     
---------------Inter-token Latency----------------
Mean ITL (ms):                           37.30     
Median ITL (ms):                         33.33     
P99 ITL (ms):                            34.65     
==================================================

Performance is looking good here. We should probably run a larger sweep just to be safe.

PS: performance with/without topk/p sampling-params (eg setting only temperature=0.2) is virtually identical as the graph being executed is the same. This is not the case when all requests in the batch have temperature=0 (all_greedy).

cc @hyeygit

Signed-off-by: NickLucche <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added v1 tpu Related to Google TPUs labels Apr 18, 2025
Signed-off-by: NickLucche <[email protected]>
@NickLucche NickLucche changed the title enable topp [TPU][V1] Enable Top-P Apr 18, 2025
Signed-off-by: NickLucche <[email protected]>
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 18, 2025
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LGTM

@DarkLight1337
Copy link
Member

Can you merge from main to fix CI?

Copy link
Collaborator

@yaochengji yaochengji left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the contribution!

@mgoin mgoin enabled auto-merge (squash) April 21, 2025 22:36
@mgoin mgoin merged commit fa3bba2 into vllm-project:main Apr 22, 2025
43 checks passed
dtransposed pushed a commit to dtransposed/vllm that referenced this pull request Apr 22, 2025
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
frieda-huang pushed a commit to frieda-huang/vllm that referenced this pull request Apr 23, 2025
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Frieda (Jingying) Huang <[email protected]>
liuzijing2014 pushed a commit to liuzijing2014/vllm that referenced this pull request Apr 25, 2025
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Zijing Liu <[email protected]>
liuzijing2014 pushed a commit to liuzijing2014/vllm that referenced this pull request Apr 25, 2025
wuisawesome pushed a commit to character-tech/vllm that referenced this pull request Apr 28, 2025
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Agata Dobrzyniewicz <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed tpu Related to Google TPUs v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants