[Kernel] GGUF MoeVec kernel #16780


Merged
merged 9 commits into vllm-project:main on May 7, 2025

Conversation

SzymonOzog
Contributor

When expert utilisation is low, this kernel runs much faster than the matmul-style MoE kernel. It also adds better support for I-quants.
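For context, a minimal PyTorch reference of the computation a per-token "vec" MoE path performs, assuming already-dequantized expert weights. This is only an illustration of the per-(token, expert) GEMV structure, not the CUDA kernel added in this PR, and the routing-weight handling here is an assumption:

import torch

def moe_vec_reference(x, w, topk_ids, topk_weights):
    # x:            [num_tokens, hidden]        activations
    # w:            [num_experts, out, hidden]  dequantized expert weights
    # topk_ids:     [num_tokens, top_k]         selected expert indices
    # topk_weights: [num_tokens, top_k]         routing weights
    num_tokens, top_k = topk_ids.shape
    out = x.new_zeros(num_tokens, w.shape[1])
    for t in range(num_tokens):
        for k in range(top_k):
            e = int(topk_ids[t, k])
            # One matrix-vector product per (token, expert) pair -- cheap when
            # only a handful of tokens (low expert utilisation) are in flight.
            out[t] += topk_weights[t, k] * (w[e] @ x[t])
    return out

# Tiny smoke test with made-up sizes.
x = torch.randn(4, 32)
w = torch.randn(8, 16, 32)
topk_ids = torch.randint(0, 8, (4, 2))
topk_weights = torch.softmax(torch.randn(4, 2), dim=-1)
print(moe_vec_reference(x, w, topk_ids, topk_weights).shape)  # torch.Size([4, 16])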

SzymonOzog and others added 4 commits April 17, 2025 08:34
Signed-off-by: SzymonOzog <[email protected]>
Signed-off-by: SzymonOzog <[email protected]>
Signed-off-by: SzymonOzog <[email protected]>
Signed-off-by: SzymonOzog <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, fastcheck CI runs only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mgoin mgoin requested a review from Isotr0py April 17, 2025 12:54
@Isotr0py Isotr0py self-assigned this Apr 17, 2025
Collaborator

@Isotr0py Isotr0py left a comment


Sorry for the delay. This looks reasonable to me once we update the GGUF kernel tests to cover the MoeVec kernel!

Comment on lines +381 to +384
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, // input
torch::Tensor W, // expert weights
torch::Tensor topk_ids, int64_t top_k,
int64_t type, int64_t row, int64_t tokens) {
Collaborator

@Isotr0py Isotr0py Apr 20, 2025


Can we also update the GGUF kernel tests to cover I-Quants with MoeVec kernel?

Contributor Author


Good idea. Updated with I-Quants
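For reference, the op from the review snippet above is exposed on the Python side through vllm/_custom_ops.py (visible in the traceback later in this thread). A hedged sketch of what that wrapper looks like; the parameter descriptions are inferred from the C++ signature, not taken from documentation:

import torch

def ggml_moe_a8_vec(X: torch.Tensor, W: torch.Tensor, topk_ids: torch.Tensor,
                    top_k: int, quant_type: int, row: int,
                    tokens: int) -> torch.Tensor:
    # X:         activations for the routed tokens
    # W:         stacked GGUF-quantized expert weights (quant_type, e.g. an I-quant)
    # topk_ids:  [tokens, top_k] expert index per token
    # row:       output rows produced per expert matrix (assumed meaning)
    # tokens:    number of tokens in the batch (assumed meaning)
    return torch.ops._C.ggml_moe_a8_vec(X, W, topk_ids, top_k, quant_type,
                                        row, tokens)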

@zhaotyer
Contributor

zhaotyer commented Apr 22, 2025

When expert utilisation is low, this kernel runs much faster than the matmul-style MoE kernel. It also adds better support for I-quants.

When I set max_model_len to 8192, the service crashes at startup.
I tested it on 2xA100x80GB and 8xL40Sx45GB, and both showed errors.

vllm serve /models/DeepSeek-R1-UD-IQ1_S/merged_file.gguf -tp 8 --trust-remote-code --trust-remote-code --tokenizer /models/DeepSeek-R1-UD-IQ1_S/ --hf-config-path /models/DeepSeek-R1-UD-IQ1_S/ --dtype bfloat16 --max-model-len 32768 --served-model-name atom --port 8160 --gpu-memory-utilization 0.95 --enable-chunked-prefill --max-num-batched-tokens 512

Error log:

INFO 04-22 04:13:35 [model_runner.py:1146] Model loading took 66.5477 GiB and 216.481170 seconds
ERROR 04-22 04:13:40 [engine.py:448] CUDA error: invalid configuration argument
ERROR 04-22 04:13:40 [engine.py:448] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
ERROR 04-22 04:13:40 [engine.py:448] For debugging consider passing CUDA_LAUNCH_BLOCKING=1
ERROR 04-22 04:13:40 [engine.py:448] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
ERROR 04-22 04:13:40 [engine.py:448] Traceback (most recent call last):
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 436, in run_mp_engine
ERROR 04-22 04:13:40 [engine.py:448]     engine = MQLLMEngine.from_vllm_config(
ERROR 04-22 04:13:40 [engine.py:448]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 128, in from_vllm_config
ERROR 04-22 04:13:40 [engine.py:448]     return cls(
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 82, in __init__
ERROR 04-22 04:13:40 [engine.py:448]     self.engine = LLMEngine(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 285, in __init__
ERROR 04-22 04:13:40 [engine.py:448]     self._initialize_kv_caches()
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 434, in _initialize_kv_caches
ERROR 04-22 04:13:40 [engine.py:448]     self.model_executor.determine_num_available_blocks())
ERROR 04-22 04:13:40 [engine.py:448]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 103, in determine_num_available_blocks
ERROR 04-22 04:13:40 [engine.py:448]     results = self.collective_rpc("determine_num_available_blocks")
ERROR 04-22 04:13:40 [engine.py:448]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 331, in collective_rpc
ERROR 04-22 04:13:40 [engine.py:448]     return self._run_workers(method, *args, **(kwargs or {}))
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/mp_distributed_executor.py", line 185, in _run_workers
ERROR 04-22 04:13:40 [engine.py:448]     driver_worker_output = run_method(self.driver_worker, sent_method,
ERROR 04-22 04:13:40 [engine.py:448]                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2428, in run_method
ERROR 04-22 04:13:40 [engine.py:448]     return func(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-22 04:13:40 [engine.py:448]     return func(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
ERROR 04-22 04:13:40 [engine.py:448]     self.model_runner.profile_run()
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-22 04:13:40 [engine.py:448]     return func(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
ERROR 04-22 04:13:40 [engine.py:448]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
ERROR 04-22 04:13:40 [engine.py:448]     self.execute_model(model_input, kv_caches, intermediate_tensors)
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-22 04:13:40 [engine.py:448]     return func(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1770, in execute_model
ERROR 04-22 04:13:40 [engine.py:448]     hidden_or_intermediate_states = model_executable(
ERROR 04-22 04:13:40 [engine.py:448]                                     ^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 04-22 04:13:40 [engine.py:448]     return self._call_impl(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 04-22 04:13:40 [engine.py:448]     return forward_call(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 703, in forward
ERROR 04-22 04:13:40 [engine.py:448]     hidden_states = self.model(input_ids, positions, intermediate_tensors,
ERROR 04-22 04:13:40 [engine.py:448]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/decorators.py", line 172, in __call__
ERROR 04-22 04:13:40 [engine.py:448]     return self.forward(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 660, in forward
ERROR 04-22 04:13:40 [engine.py:448]     hidden_states, residual = layer(positions, hidden_states, residual)
ERROR 04-22 04:13:40 [engine.py:448]                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 04-22 04:13:40 [engine.py:448]     return self._call_impl(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 04-22 04:13:40 [engine.py:448]     return forward_call(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 580, in forward
ERROR 04-22 04:13:40 [engine.py:448]     hidden_states = self.mlp(hidden_states)
ERROR 04-22 04:13:40 [engine.py:448]                     ^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 04-22 04:13:40 [engine.py:448]     return self._call_impl(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 04-22 04:13:40 [engine.py:448]     return forward_call(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 159, in forward
ERROR 04-22 04:13:40 [engine.py:448]     final_hidden_states = self.experts(
ERROR 04-22 04:13:40 [engine.py:448]                           ^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 04-22 04:13:40 [engine.py:448]     return self._call_impl(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 04-22 04:13:40 [engine.py:448]     return forward_call(*args, **kwargs)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/layer.py", line 842, in forward
ERROR 04-22 04:13:40 [engine.py:448]     return self.forward_impl(hidden_states, router_logits)
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/layer.py", line 861, in forward_impl
ERROR 04-22 04:13:40 [engine.py:448]     final_hidden_states = self.quant_method.apply(
ERROR 04-22 04:13:40 [engine.py:448]                           ^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization/gguf.py", line 377, in apply
ERROR 04-22 04:13:40 [engine.py:448]     return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization/gguf.py", line 176, in _fused_moe_gguf
ERROR 04-22 04:13:40 [engine.py:448]     out = ops.ggml_moe_a8_vec(out, w2, topk_ids, 1, qweight_type2,
ERROR 04-22 04:13:40 [engine.py:448]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/_custom_ops.py", line 1179, in ggml_moe_a8_vec
ERROR 04-22 04:13:40 [engine.py:448]     return torch.ops._C.ggml_moe_a8_vec(X, W, topk_ids, top_k, quant_type, row,
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1123, in __call__
ERROR 04-22 04:13:40 [engine.py:448]     return self._op(*args, **(kwargs or {}))
ERROR 04-22 04:13:40 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-22 04:13:40 [engine.py:448] RuntimeError: CUDA error: invalid configuration argument
ERROR 04-22 04:13:40 [engine.py:448] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
ERROR 04-22 04:13:40 [engine.py:448] For debugging consider passing CUDA_LAUNCH_BLOCKING=1
ERROR 04-22 04:13:40 [engine.py:448] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

@SzymonOzog
Contributor Author

@zhaotyer

Could you try with --enable-chunked-prefill --max-num-batched-tokens 512?
I think I know what's causing the issue; it's also present on main. We would need an option to run the kernel in chunks, similarly to how it's done in Triton. I'll add that in a follow-up PR.
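A hedged illustration of that chunking idea (not the actual fix): CUDA grid dimensions other than x are capped at 65535, so a launch that maps a very large token count onto one of those dimensions can fail with "invalid configuration argument"; splitting the tokens into fixed-size chunks keeps each launch within limits. The chunk size and helper below are hypothetical:

import torch

MAX_CHUNK = 512  # hypothetical chunk size, e.g. matching --max-num-batched-tokens

def moe_vec_in_chunks(run_vec_kernel, X: torch.Tensor, topk_ids: torch.Tensor,
                      max_chunk: int = MAX_CHUNK) -> torch.Tensor:
    # run_vec_kernel is a stand-in for the real op: it takes a slice of the
    # activations plus the matching topk_ids slice and returns that slice's output.
    outputs = []
    for start in range(0, X.shape[0], max_chunk):
        chunk = slice(start, start + max_chunk)
        outputs.append(run_vec_kernel(X[chunk], topk_ids[chunk]))
    return torch.cat(outputs, dim=0)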

@zhaotyer
Contributor

@zhaotyer

Could you try with --enable-chunked-prefill --max-num-batched-tokens 512? I think I know what's causing the issue; it's also present on main. We would need an option to run the kernel in chunks, similarly to how it's done in Triton. I'll add that in a follow-up PR.

It works well. Below are the benchmark results on 8xL40Sx45GB.
Start command:

vllm serve /models/DeepSeek-R1-UD-IQ1_S/merge.gguf -tp 8 --trust-remote-code --trust-remote-code --tokenizer /models/DeepSeek-R1-UD-IQ1_S/ --hf-config-path /models/DeepSeek-R1-UD-IQ1_S/ --dtype bfloat16 --max-model-len 32768 --served-model-name atom --port 8160 --gpu-memory-utilization 0.85  --enable-chunked-prefill --max-num-batched-tokens 512

Benchmark result:

root@llm16:/llm_benchmark/vllm_benchmark# ./run_benshmark.sh 
==============================================
 Running benchmarks for INPUT_LEN=1024, OUTPUT_LEN=1024
==============================================

Running benchmark with BATCH_SIZE=1
Namespace(backend='openai-chat', base_url=None, host='127.0.0.1', port=8160, endpoint='/v1/chat/completions', dataset=None, dataset_name='random', dataset_path='../data/sonnet.txt', model='atom', tokenizer='../tokenizer', best_of=1, use_beam_search=False, num_prompts=1, sharegpt_output_len=None, sonnet_input_len=1024, sonnet_output_len=1024, sonnet_prefix_len=30, random_input_len=1024, random_output_len=1024, random_range_ratio=1.0, request_rate=inf, seed=0, trust_remote_code=False, disable_tqdm=False, save_result=False, metadata=None, result_dir=None, result_filename=None)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:27<00:00, 27.54s/it]
============ Serving Benchmark Result ============
Successful requests:                     1         
Benchmark duration (s):                  27.54     
Total input tokens:                      1024      
Total generated tokens:                  1005      
Request throughput (req/s):              0.04      
Input token throughput (tok/s):          37.18     
Output token throughput (tok/s):         36.49     
---------------Time to First Token----------------
Mean TTFT (ms):                          4890.51   
Median TTFT (ms):                        4890.51   
P99 TTFT (ms):                           4890.51   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          22.56     
Median TPOT (ms):                        22.56     
P99 TPOT (ms):                           22.56     
---------------Inter-token Latency----------------
Mean ITL (ms):                           22.14     
Median ITL (ms):                         22.14     
P99 ITL (ms):                            23.31     
==================================================

Running benchmark with BATCH_SIZE=2
Namespace(backend='openai-chat', base_url=None, host='127.0.0.1', port=8160, endpoint='/v1/chat/completions', dataset=None, dataset_name='random', dataset_path='../data/sonnet.txt', model='atom', tokenizer='../tokenizer', best_of=1, use_beam_search=False, num_prompts=2, sharegpt_output_len=None, sonnet_input_len=1024, sonnet_output_len=1024, sonnet_prefix_len=30, random_input_len=1024, random_output_len=1024, random_range_ratio=1.0, request_rate=inf, seed=0, trust_remote_code=False, disable_tqdm=False, save_result=False, metadata=None, result_dir=None, result_filename=None)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:48<00:00, 24.31s/it]
============ Serving Benchmark Result ============
Successful requests:                     2         
Benchmark duration (s):                  48.63     
Total input tokens:                      2048      
Total generated tokens:                  1955      
Request throughput (req/s):              0.04      
Input token throughput (tok/s):          42.12     
Output token throughput (tok/s):         40.20     
---------------Time to First Token----------------
Mean TTFT (ms):                          9883.16   
Median TTFT (ms):                        9883.16   
P99 TTFT (ms):                           10907.44  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          39.64     
Median TPOT (ms):                        39.64     
P99 TPOT (ms):                           39.85     
---------------Inter-token Latency----------------
Mean ITL (ms):                           37.83     
Median ITL (ms):                         36.89     
P99 ITL (ms):                            38.44     
==================================================


Running benchmark with BATCH_SIZE=4

Namespace(backend='openai-chat', base_url=None, host='127.0.0.1', port=8160, endpoint='/v1/chat/completions', dataset=None, dataset_name='random', dataset_path='../data/sonnet.txt', model='atom', tokenizer='../tokenizer', best_of=1, use_beam_search=False, num_prompts=4, sharegpt_output_len=None, sonnet_input_len=1024, sonnet_output_len=1024, sonnet_prefix_len=30, random_input_len=1024, random_output_len=1024, random_range_ratio=1.0, request_rate=inf, seed=0, trust_remote_code=False, disable_tqdm=False, save_result=False, metadata=None, result_dir=None, result_filename=None)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [01:04<00:00, 16.14s/it]
============ Serving Benchmark Result ============
Successful requests:                     4         
Benchmark duration (s):                  64.58     
Total input tokens:                      4096      
Total generated tokens:                  3804      
Request throughput (req/s):              0.06      
Input token throughput (tok/s):          63.43     
Output token throughput (tok/s):         58.91     
---------------Time to First Token----------------
Mean TTFT (ms):                          14266.67  
Median TTFT (ms):                        14534.42  
P99 TTFT (ms):                           20636.03  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          53.21     
Median TPOT (ms):                        50.28     
P99 TPOT (ms):                           65.93     
---------------Inter-token Latency----------------
Mean ITL (ms):                           49.03     
Median ITL (ms):                         42.96     
P99 ITL (ms):                            44.96     
==================================================

Running benchmark with BATCH_SIZE=8
Namespace(backend='openai-chat', base_url=None, host='127.0.0.1', port=8160, endpoint='/v1/chat/completions', dataset=None, dataset_name='random', dataset_path='../data/sonnet.txt', model='atom', tokenizer='../tokenizer', best_of=1, use_beam_search=False, num_prompts=8, sharegpt_output_len=None, sonnet_input_len=1024, sonnet_output_len=1024, sonnet_prefix_len=30, random_input_len=1024, random_output_len=1024, random_range_ratio=1.0, request_rate=inf, seed=0, trust_remote_code=False, disable_tqdm=False, save_result=False, metadata=None, result_dir=None, result_filename=None)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [01:32<00:00, 11.50s/it]
============ Serving Benchmark Result ============
Successful requests:                     8         
Benchmark duration (s):                  92.00     
Total input tokens:                      8192      
Total generated tokens:                  5667      
Request throughput (req/s):              0.09      
Input token throughput (tok/s):          89.04     
Output token throughput (tok/s):         61.60     
---------------Time to First Token----------------
Mean TTFT (ms):                          19518.43  
Median TTFT (ms):                        18920.55  
P99 TTFT (ms):                           38009.99  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          72.88     
Median TPOT (ms):                        73.68     
P99 TPOT (ms):                           97.07     
---------------Inter-token Latency----------------
Mean ITL (ms):                           66.69     
Median ITL (ms):                         52.78     
P99 ITL (ms):                            443.65    
==================================================

Running benchmark with BATCH_SIZE=16
Namespace(backend='openai-chat', base_url=None, host='127.0.0.1', port=8160, endpoint='/v1/chat/completions', dataset=None, dataset_name='random', dataset_path='../data/sonnet.txt', model='atom', tokenizer='../tokenizer', best_of=1, use_beam_search=False, num_prompts=16, sharegpt_output_len=None, sonnet_input_len=1024, sonnet_output_len=1024, sonnet_prefix_len=30, random_input_len=1024, random_output_len=1024, random_range_ratio=1.0, request_rate=inf, seed=0, trust_remote_code=False, disable_tqdm=False, save_result=False, metadata=None, result_dir=None, result_filename=None)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [02:36<00:00,  9.79s/it]
============ Serving Benchmark Result ============
Successful requests:                     16        
Benchmark duration (s):                  156.67    
Total input tokens:                      16384     
Total generated tokens:                  12957     
Request throughput (req/s):              0.10      
Input token throughput (tok/s):          104.58    
Output token throughput (tok/s):         82.70     
---------------Time to First Token----------------
Mean TTFT (ms):                          43333.50  
Median TTFT (ms):                        43046.04  
P99 TTFT (ms):                           78300.02  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          135.15    
Median TPOT (ms):                        132.15    
P99 TPOT (ms):                           229.35    
---------------Inter-token Latency----------------
Mean ITL (ms):                           111.62    
Median ITL (ms):                         77.29     
P99 ITL (ms):                            1453.83   
==================================================

Running benchmark with BATCH_SIZE=32
Namespace(backend='openai-chat', base_url=None, host='127.0.0.1', port=8160, endpoint='/v1/chat/completions', dataset=None, dataset_name='random', dataset_path='../data/sonnet.txt', model='atom', tokenizer='../tokenizer', best_of=1, use_beam_search=False, num_prompts=32, sharegpt_output_len=None, sonnet_input_len=1024, sonnet_output_len=1024, sonnet_prefix_len=30, random_input_len=1024, random_output_len=1024, random_range_ratio=1.0, request_rate=inf, seed=0, trust_remote_code=False, disable_tqdm=False, save_result=False, metadata=None, result_dir=None, result_filename=None)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [04:40<00:00,  8.76s/it]
============ Serving Benchmark Result ============
Successful requests:                     32        
Benchmark duration (s):                  280.46    
Total input tokens:                      32768     
Total generated tokens:                  26877     
Request throughput (req/s):              0.11      
Input token throughput (tok/s):          116.84    
Output token throughput (tok/s):         95.83     
---------------Time to First Token----------------
Mean TTFT (ms):                          82584.55  
Median TTFT (ms):                        82371.42  
P99 TTFT (ms):                           155532.25 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          360.14    
Median TPOT (ms):                        203.56    
P99 TPOT (ms):                           2685.83   
---------------Inter-token Latency----------------
Mean ITL (ms):                           194.41    
Median ITL (ms):                         126.07    
P99 ITL (ms):                            1471.76   
==================================================

@SzymonOzog Very fast, thanks for your work

Signed-off-by: SzymonOzog <[email protected]>

mergify bot commented Apr 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @SzymonOzog.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 23, 2025
Signed-off-by: SzymonOzog <[email protected]>
@mergify mergify bot removed the needs-rebase label Apr 23, 2025
@Isotr0py Isotr0py added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 23, 2025
@DarkLight1337
Member

Can you merge from main to fix the Docker build?

Signed-off-by: SzymonOzog <[email protected]>
@SzymonOzog
Contributor Author

@DarkLight1337 Merged main

@DarkLight1337
Member

PTAL at the failing installation test. It seems related to this PR.

@SzymonOzog
Contributor Author

@DarkLight1337 The test seems to use a precompiled nightly wheel where the kernel from this PR is not yet present; that's why @register_fake is failing. Can I do something to include it in the nightly wheel pre-merge?

@DarkLight1337
Member

@khluu @mgoin @LucasWilkinson any ideas?

@Isotr0py
Collaborator

Isotr0py commented May 4, 2025

@SzymonOzog Can you merge from main to see if the python-only-installation test still fails?

Signed-off-by: SzymonOzog <[email protected]>
@SzymonOzog
Contributor Author

Updated to main

@Isotr0py Isotr0py enabled auto-merge (squash) May 6, 2025 14:12
@vllm-bot vllm-bot merged commit 1a45a61 into vllm-project:main May 7, 2025
77 of 81 checks passed
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: SzymonOzog <[email protected]>
Signed-off-by: SzymonOzog <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
Signed-off-by: SzymonOzog <[email protected]>
Signed-off-by: SzymonOzog <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
Signed-off-by: SzymonOzog <[email protected]>
Signed-off-by: SzymonOzog <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Signed-off-by: Yuqi Zhang <[email protected]>
Labels
ready ONLY add when PR is ready to merge/full CI is needed
5 participants