
[Kernel] DeepEP dispatch-combine kernel integration #18434

Merged · 5 commits · Jun 3, 2025

Conversation

Contributor

@varun-sundar-rabindranath commented May 20, 2025

Integrate DeepEP dispatch-combine kernels

  • Integrate the DeepEP high-throughput and low-latency dispatch-combine kernels
  • Integrate the DeepEP high-throughput kernel with the corresponding DeepGemm kernel

Correctness:

  • Tested correctness using lm_eval on H100 with:
    Models: deepseek-ai/DeepSeek-V2-Lite, RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8, Qwen/Qwen3-30B-A3B-FP8
    ALL2ALL Backend: deepep_high_throughput
    Case: DP=2 TP=1
  • Tested correctness using lm_eval on H100 with:
    Models: deepseek-ai/DeepSeek-V2-Lite, RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8
    ALL2ALL Backend: deepep_low_latency
    Case: DP=2 TP=1
    Note: the DeepEP low-latency kernels are compiled only for a fixed set of hidden sizes, and DeepSeek-V2-Lite's hidden size is not among them. I had to update DeepEP to support this hidden size in order to run the test.
  • Tested correctness using lm_eval on A100 with:
    Models: deepseek-ai/DeepSeek-V2-Lite
    ALL2ALL Backend: pplx
    Case: DP=2 TP=1


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@varun-sundar-rabindranath marked this pull request as draft on May 20, 2025, 20:00

mergify bot commented May 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


mergify bot commented May 29, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 29, 2025
Comment on lines 462 to 465
# TODO (varun) : deepgemm integration
self.use_batched_experts = False
if envs.VLLM_ALL2ALL_BACKEND == "deepep_ll":
self.use_batched_experts = True
Contributor

It might be better to add a method to prepare_finalize that says whether the format is batched, instead of using an env var or checking the instance type.

Contributor Author

Introduced a function max_num_tokens_per_rank() on the prepare_finalize objects - we can use it to determine batching 👍
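
For illustration, a minimal sketch of how such a method could drive the batching decision (the "None means non-batched" convention is an assumption for illustration, not necessarily the PR's contract):

from typing import Optional

def use_batched_experts(prepare_finalize) -> bool:
    # Hypothetical helper: a fixed per-rank token capacity implies the
    # expert-batched activation format (an assumption for illustration).
    max_tokens: Optional[int] = prepare_finalize.max_num_tokens_per_rank()
    return max_tokens is not None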

@mergify mergify bot removed the needs-rebase label May 30, 2025
@varun-sundar-rabindranath marked this pull request as ready for review on May 30, 2025, 21:57
Collaborator

@tlrmchlsmth left a comment

Looks good. I just left a few minor comments. Going to try this out in a multinode setup

# weights have already been applied.
combine_topk_weights = torch.ones_like(topk_weights)

# TODO (varun) : Enable zero copy mode
Contributor

Still TODO?

Contributor Author

Yeah. It could be a fast-follow.

and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
return self.deep_gemm_expert.apply(
and _valid_deep_gemm(hidden_states, w1, w2)):
return self.deep_gemm_expert.apply( #type: ignore
Collaborator

Why do we need the #type: ignore? Could you add a comment (or better: resolve the type issues)?

Contributor Author

deep_gemm_expert is an Optional (its existence depends on the self.allow_deep_gemm flag, which is checked right above this line). Let me see if an assert right above fixes it - I also didn't want to unnecessarily introduce an assert on the hot path.
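
As a self-contained illustration of the assert-based narrowing being discussed (the class and attribute names below are made up, not the PR's code):

from typing import Optional

class DeepGemmExpert:
    def apply(self, x: float) -> float:
        return 2.0 * x

class ExpertsWrapper:
    def __init__(self, allow_deep_gemm: bool):
        self.allow_deep_gemm = allow_deep_gemm
        # Only constructed when DeepGemm is allowed, hence Optional.
        self.deep_gemm_expert: Optional[DeepGemmExpert] = (
            DeepGemmExpert() if allow_deep_gemm else None)

    def forward(self, x: float) -> float:
        if self.allow_deep_gemm:
            # Narrows Optional[DeepGemmExpert] to DeepGemmExpert for mypy,
            # avoiding the "# type: ignore".
            assert self.deep_gemm_expert is not None
            return self.deep_gemm_expert.apply(x)
        return x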

Comment on lines 324 to 325
num_nvl_bytes=1024 * 1024 * 1024, # 1 GiB
num_rdma_bytes=0,
Collaborator

Is this right for the multinode case? Thinking we might need to set num_rdma_bytes > 0 in that case.

Contributor Author

You are correct - it does need to be nonzero for the internode case. Let me fix that 👍 Nice catch!

Contributor Author

@tlrmchlsmth - fixed in b841fac. I got the defaults from the DeepEP tests; we can update them if need be.
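
A rough sketch of the kind of switch being discussed (the byte sizes are the DeepEP-test defaults mentioned above, used here as assumptions rather than tuned values):

def make_all_to_all_args(num_nodes: int) -> dict:
    internode = num_nodes > 1
    return dict(
        num_nvl_bytes=1024 * 1024 * 1024,  # 1 GiB NVLink (intranode) buffer
        # The RDMA buffer must be non-zero once ranks span multiple nodes.
        num_rdma_bytes=1024 * 1024 * 1024 if internode else 0,
        low_latency_mode=False,
        num_qps_per_rank=1,
    )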

num_nvl_bytes=1024 * 1024 * 1024, # 1 GiB
num_rdma_bytes=0,
low_latency_mode=False,
num_qps_per_rank=1)
Collaborator

@varun-sundar-rabindranath do you know what this argument is for?

Contributor Author

from the docs, this is the number of parallel RDMA connections each rank can establish.

Contributor Author

It is tied to the NVSHMEM_IBGDA_NUM_RC_PER_PE environment variable in the code.

Collaborator

Do you mean that we should be reading the NVSHMEM_IBGDA_NUM_RC_PER_PE env and passing it in here?

Comment on lines +90 to +99
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w2 = (n + block_k - 1) // block_k
Collaborator

nit: More readable to use e.g. k_tiles_w1 = round_up(k, block_k)

vllm/vllm/utils.py, lines 729 to 730 in c57d577:

def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y
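
The expressions in the snippet above are ceiling divisions (tile counts), so a small helper states that intent directly; cdiv and the shapes below are illustrative only:

def cdiv(a: int, b: int) -> int:
    """Ceiling division: the number of b-sized tiles needed to cover a."""
    return (a + b - 1) // b

# Example shapes, for illustration only.
n, k = 2048, 7168
block_n, block_k = 128, 128

n_tiles_w1 = cdiv(2 * n, block_n)  # == ((2 * n) + block_n - 1) // block_n
k_tiles_w1 = cdiv(k, block_k)
n_tiles_w2 = cdiv(k, block_n)
k_tiles_w2 = cdiv(n, block_k)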

Collaborator

@tlrmchlsmth left a comment

Looks good overall.

I had some questions on the construction of the all_to_all_args for the HT and LL cases -- want to make sure we're good on num_rdma_bytes, num_qps_per_rank before landing.

Other stuff is pretty minor

apply_router_weight_on_input: bool,
output_dtype: torch.dtype):

if fused_expert_output.ndim == 2:
Contributor

Why would fused_expert_output have varying ndim?

Contributor Author

The DeepEP high-throughput dispatch kernel does not produce batched output; as a result we end up using TritonOrDeepGemmExperts, and the output of those "experts" is 2-dimensional.
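
For illustration only (the sizes are made up), the shape difference being described:

import torch

hidden = 2048
# Non-batched path (DeepEP HT dispatch + TritonOrDeepGemmExperts): [num_tokens, hidden]
ht_expert_output = torch.randn(32, hidden)
# Batched path: [num_local_experts, max_tokens_per_rank, hidden]
batched_expert_output = torch.randn(8, 16, hidden)

assert ht_expert_output.ndim == 2
assert batched_expert_output.ndim == 3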

max_tokens_per_rank: int,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
use_fp8_dispatch: bool = False):
Contributor

Does this flag indicate that the input has already been quantized?

Contributor Author

No. This is a performance-related option in the low-latency kernels.
The low-latency dispatch kernel takes bfloat16 inputs.
This option tells the kernel to quantize the inputs internally and dispatch them as fp8. The kernel outputs tokens and scales, which we dequantize on the receiving end.
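
A minimal, self-contained sketch of that quantize-on-dispatch / dequantize-on-receive round trip (per-token scaling here is a simplification for illustration, not DeepEP's actual scaling granularity or kernel interface):

import torch

def quantize_for_dispatch(x: torch.Tensor):
    # x: [num_tokens, hidden] bfloat16 activations.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4).float()
    scale = amax / torch.finfo(torch.float8_e4m3fn).max
    x_fp8 = (x.float() / scale).to(torch.float8_e4m3fn)
    return x_fp8, scale  # both the tokens and the scales travel over the network

def dequantize_on_receive(x_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return (x_fp8.float() * scale).to(torch.bfloat16)

x = torch.randn(16, 2048, dtype=torch.bfloat16)
x_fp8, scale = quantize_for_dispatch(x)
x_roundtrip = dequantize_on_receive(x_fp8, scale)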

Comment on lines +96 to +101
if apply_router_weight_on_input:
topk = rank_topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * rank_topk_weights.to(a1.dtype)
Contributor

I feel like this snippet has been repeated enough that we should make a utility out of it at some point. We can save it for a later PR tho.

Contributor Author

Yeah, I felt the same - something like maybe_apply_router_weight_on_input. But I agree we can put it in a later PR.
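
A sketch of such a utility, based on the snippet above (the name and signature are just the ones floated in this thread, not the PR's code):

import torch

def maybe_apply_router_weight_on_input(
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool) -> torch.Tensor:
    if not apply_router_weight_on_input:
        return a1
    topk = topk_ids.size(1)
    # TODO: this only works for topk=1 and will need updating for topk>1.
    assert topk == 1, (
        "apply_router_weight_on_input is only implemented for topk=1")
    return a1 * topk_weights.to(a1.dtype)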

Comment on lines 438 to 442
use_batched_experts = (
isinstance(prepare_finalize, BatchedPrepareAndFinalize) or
(has_pplx and isinstance(prepare_finalize, PplxPrepareAndFinalize))
or (has_deepep
and isinstance(prepare_finalize, DeepEPLLPrepareAndFinalize)))
Contributor

We should add a method to the prepare and finalize class that returns the activation format, e.g. expert-batched vs. non-batched. Then we won't need all the isinstance checks. I'm fine with doing this in another PR tho.

Contributor Author

fixed it 👍
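
For reference, a minimal sketch of the kind of activation-format hook being suggested (the enum and method names are illustrative, not the PR's final API):

from enum import Enum

class ActivationFormat(Enum):
    STANDARD = "standard"          # [num_tokens, hidden]
    BATCHED_EXPERTS = "batched"    # [num_experts, max_tokens_per_rank, hidden]

class PrepareAndFinalizeBase:
    @property
    def activation_format(self) -> ActivationFormat:
        return ActivationFormat.STANDARD

class DeepEPLLPrepareAndFinalize(PrepareAndFinalizeBase):
    @property
    def activation_format(self) -> ActivationFormat:
        return ActivationFormat.BATCHED_EXPERTS

def use_batched_experts(pf: PrepareAndFinalizeBase) -> bool:
    # Replaces the chain of isinstance checks shown above.
    return pf.activation_format is ActivationFormat.BATCHED_EXPERTS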

Comment on lines +926 to +915
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
Contributor

This isn't applicable for the ht kernels?

Contributor Author

No. The HT kernels are not batched.

Comment on lines +1322 to +1311
if (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels):
Contributor

Maybe we should add a new umbrella property for all these types of kernels?

@@ -1305,12 +1408,17 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
if self.moe_parallel_config.use_pplx_kernels:
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
Contributor

No ht kernels here either?

Contributor Author

No. The HT kernels aren't batched.

Comment on lines +1415 to +1406
do_naive_dispatch_combine: bool = (
self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels)
if do_naive_dispatch_combine:
Contributor

Another future TODO: put the naive dispatch/combine into a NaivePrepareAndFinalize class. (I was planning on doing this but thought I'd mention it for posterity.)

@@ -459,8 +462,10 @@ def __init__(self, quant_config: Fp8Config):
logger.warning_once(
"DeepGemm not supported on the current platform.")

self.topk_indices_dtype = None
Contributor

Is this initialized in the base class and set when select_gemm_impl is called? Maybe we can do away with this?

Contributor Author

There is actually an edge case - with DP=1 TP=2 and enable_expert_parallel, select_gemm_impl isn't called at all. I ran into this when testing.

Collaborator

@tlrmchlsmth left a comment

Looks good to me after the latest commit

@tlrmchlsmth added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Jun 3, 2025
Varun Sundar Rabindranath added 2 commits June 3, 2025 04:47
Signed-off-by: Varun <[email protected]>
Signed-off-by: Varun <[email protected]>
Varun added 3 commits June 3, 2025 05:39
Signed-off-by: Varun <[email protected]>
Signed-off-by: Varun <[email protected]>
@simon-mo merged commit fa98d77 into vllm-project:main on Jun 3, 2025
93 of 97 checks passed
Labels
ready (ONLY add when PR is ready to merge/full CI is needed), v1
4 participants