[Kernel] DeepEP dispatch-combine kernel integration #18434
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a reduced set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 3bc6ab7 to 74be347.
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py - outdated review thread (resolved)
This pull request has merge conflicts that must be resolved before it can be merged.
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py - outdated review thread (resolved)
vllm/model_executor/layers/fused_moe/deepep_prepare_finalize.py - outdated review thread (resolved)
# TODO (varun) : deepgemm integration
self.use_batched_experts = False
if envs.VLLM_ALL2ALL_BACKEND == "deepep_ll":
    self.use_batched_experts = True
It might be better to add a method to prepare_finalize that says whether or not the format is batched, instead of using an env var or checking the instance type.
Introduced a function max_num_tokens_per_rank() on the prepare_finalize objects - we can use it to determine batching 👍
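For illustration, a minimal sketch of that idea - the class bodies and attribute names below are assumptions for the sketch, not the exact vLLM code; only the max_num_tokens_per_rank() hook comes from the comment above:

```python
from typing import Optional


class DeepEPLLPrepareAndFinalize:
    """Low-latency variant: expert inputs arrive in fixed-size per-rank batches."""

    def __init__(self, max_tokens_per_rank: int):
        self.max_tokens_per_rank = max_tokens_per_rank

    def max_num_tokens_per_rank(self) -> Optional[int]:
        # A bounded value signals the expert-batched activation format.
        return self.max_tokens_per_rank


class DeepEPHTPrepareAndFinalize:
    """High-throughput variant: activations stay in a flat 2-D layout."""

    def max_num_tokens_per_rank(self) -> Optional[int]:
        # None signals the standard, non-batched format.
        return None


def uses_batched_experts(prepare_finalize) -> bool:
    # Replaces the VLLM_ALL2ALL_BACKEND env-var check shown above.
    return prepare_finalize.max_num_tokens_per_rank() is not None
```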
Force-pushed from 17fa499 to 3fa74f4.
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py - outdated review thread (resolved)
vllm/model_executor/layers/fused_moe/deepep_prepare_finalize.py - outdated review thread (resolved)
Looks good. I just left a few minor comments. Going to try this out in a multinode setup.
Force-pushed from 7a129ec to 62916dd.
# weights have already been applied.
combine_topk_weights = torch.ones_like(topk_weights)

# TODO (varun) : Enable zero copy mode
Still TODO?
Yeah. It could be a fast-follow.
vllm/model_executor/layers/fused_moe/deepep_prepare_finalize.py - outdated review thread (resolved)
-        and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
-    return self.deep_gemm_expert.apply(
+        and _valid_deep_gemm(hidden_states, w1, w2)):
+    return self.deep_gemm_expert.apply(  # type: ignore
Why do we need the #type: ignore? Could you add a comment (or better: resolve the type issues)?
deep_gemm_expert is an Optional (its existence depends on the self.allow_deep_gemm variable, which is checked right above this line). Let me see if an assert right above fixes it - I also didn't want to unnecessarily introduce an assert on the hot path.
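For illustration, a minimal sketch (with placeholder class names, not the actual vLLM classes) of how an assert narrows the Optional so mypy accepts the .apply() call without # type: ignore:

```python
from typing import Optional


class DeepGemmExperts:
    def apply(self, x: float) -> float:
        return 2.0 * x


class FusedExperts:
    def __init__(self, allow_deep_gemm: bool):
        self.allow_deep_gemm = allow_deep_gemm
        self.deep_gemm_expert: Optional[DeepGemmExperts] = (
            DeepGemmExperts() if allow_deep_gemm else None)

    def apply(self, x: float) -> float:
        if self.allow_deep_gemm:
            # mypy only knows deep_gemm_expert is Optional; the assert
            # narrows it to DeepGemmExperts, so no "# type: ignore" is needed.
            assert self.deep_gemm_expert is not None
            return self.deep_gemm_expert.apply(x)
        return x
```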
num_nvl_bytes=1024 * 1024 * 1024,  # 1 GiB
num_rdma_bytes=0,
Is this right for the multinode case? Thinking we might need to set num_rdma_bytes > 0 in that case.
You are correct - it does need to be non-zero for the internode case. Let me fix that 👍 Nice catch!
@tlrmchlsmth - fixed it in b841fac. I got the defaults from the DeepEP tests; we can update them if need be.
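A hedged sketch of how the buffer arguments might be chosen for the intranode vs. internode cases; the helper name and the 1 GiB placeholder values are illustrative assumptions, not the defaults adopted in b841fac:

```python
def select_all_to_all_args(num_nodes: int) -> dict:
    """Pick DeepEP buffer sizes for intra- vs. inter-node deployments (sketch)."""
    intranode = num_nodes == 1
    return dict(
        # NVLink buffer used for intranode traffic.
        num_nvl_bytes=1024 * 1024 * 1024,  # 1 GiB (placeholder)
        # The RDMA buffer must be non-zero once ranks span multiple nodes.
        num_rdma_bytes=0 if intranode else 1024 * 1024 * 1024,  # placeholder
        low_latency_mode=False,
        num_qps_per_rank=1,
    )
```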
num_nvl_bytes=1024 * 1024 * 1024,  # 1 GiB
num_rdma_bytes=0,
low_latency_mode=False,
num_qps_per_rank=1)
@varun-sundar-rabindranath do you know what this argument is for?
From the docs, this is the number of parallel RDMA connections each rank can establish.
It is assigned to the NVSHMEM_IBGDA_NUM_RC_PER_PE environment variable in the code.
Do you mean that we should be reading the NVSHMEM_IBGDA_NUM_RC_PER_PE env and passing it in here?
No - what is passed here is set as that env var, see https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L79
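In other words, roughly what the linked DeepEP code does (a simplified sketch, not vLLM code): the constructor argument is exported as that environment variable rather than read from it.

```python
import os


def configure_nvshmem_qps(num_qps_per_rank: int) -> None:
    # DeepEP's Buffer forwards the argument to NVSHMEM through this env var;
    # callers pass the value in, they do not read the env var themselves.
    os.environ["NVSHMEM_IBGDA_NUM_RC_PER_PE"] = str(num_qps_per_rank)
```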
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py - outdated review thread (resolved)
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w2 = (n + block_k - 1) // block_k
nit: More readable to use e.g. k_tiles_w1 = round_up(k, block_k)

round_up is defined at lines 729 to 730 in c57d577:

def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y
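For reference, the tile-count expressions above are ceiling divisions; a small sketch with a hypothetical cdiv helper makes that explicit:

```python
def cdiv(x: int, y: int) -> int:
    """Ceiling division: the smallest integer >= x / y."""
    return (x + y - 1) // y


def tile_counts(n: int, k: int, block_n: int, block_k: int):
    # Equivalent to the (x + block - 1) // block expressions quoted above.
    n_tiles_w1 = cdiv(2 * n, block_n)
    k_tiles_w1 = cdiv(k, block_k)
    n_tiles_w2 = cdiv(k, block_n)
    k_tiles_w2 = cdiv(n, block_k)
    return n_tiles_w1, k_tiles_w1, n_tiles_w2, k_tiles_w2
```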
Looks good overall. I had some questions on the construction of the all_to_all_args for the HT and LL cases -- want to make sure we're good on num_rdma_bytes and num_qps_per_rank before landing. Other stuff is pretty minor.
apply_router_weight_on_input: bool,
output_dtype: torch.dtype):

if fused_expert_output.ndim == 2:
Why would fused_expert_output have varying ndim?
The DeepEP high-throughput dispatch kernel does not produce batched output - as a result we end up using the TritonOrDeepGemmExperts, and the output of those "experts" is 2-dimensional.
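A rough sketch of why finalize has to handle both layouts - the shapes and the helper name are assumptions for illustration, not the actual vLLM implementation:

```python
import torch


def flatten_expert_output(fused_expert_output: torch.Tensor) -> torch.Tensor:
    """Normalize expert output to a flat (num_tokens, hidden) layout."""
    if fused_expert_output.ndim == 2:
        # High-throughput path: already (num_tokens, hidden).
        return fused_expert_output
    # Batched (low-latency) path: (num_local_experts, max_tokens_per_rank, hidden).
    num_experts, max_tokens, hidden = fused_expert_output.shape
    return fused_expert_output.reshape(num_experts * max_tokens, hidden)
```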
max_tokens_per_rank: int,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
use_fp8_dispatch: bool = False):
Does this flag indicate that the input has already been quantized?
No. This is a performance-related option in the low-latency kernels. The low-latency dispatch kernel only accepts bfloat16 inputs; this option informs the kernel to quantize the inputs internally and dispatch them as fp8. The kernel outputs tokens and scales, which we dequantize on the receiving end.
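A hedged sketch of that receive-side dequantization, assuming a simple per-group scale layout along the hidden dimension (the real kernel's scale layout and group size may differ):

```python
import torch


def dequant_fp8_tokens(tokens_fp8: torch.Tensor,
                       scales: torch.Tensor,
                       group_size: int = 128) -> torch.Tensor:
    """Dequantize fp8 tokens using per-group scales along the hidden dim.

    tokens_fp8: (num_tokens, hidden) in torch.float8_e4m3fn
    scales:     (num_tokens, hidden // group_size) in float32
    Assumes hidden is divisible by group_size.
    """
    num_tokens, hidden = tokens_fp8.shape
    x = tokens_fp8.to(torch.float32).view(num_tokens, -1, group_size)
    x = x * scales.unsqueeze(-1)  # broadcast one scale per group
    return x.view(num_tokens, hidden).to(torch.bfloat16)
```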
if apply_router_weight_on_input:
    topk = rank_topk_ids.size(1)
    # TODO: this only works for topK=1, will need to update for topK>1
    assert topk == 1, (
        "apply_router_weight_on_input is only implemented for topk=1")
    a1 = a1 * rank_topk_weights.to(a1.dtype)
I feel like this snippet has been repeated enough that we should make a utility out of it at some point. We can save it for a later PR though.
Yeah, I felt the same - something like maybe_apply_router_weight_on_input. But I agree we can put it in a later PR.
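A possible shape for that helper - the name comes from the comment above, while the body below is just a sketch based on the quoted snippet, not the eventual vLLM utility:

```python
import torch


def maybe_apply_router_weight_on_input(
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool) -> torch.Tensor:
    """Pre-scale activations by router weights when requested."""
    if not apply_router_weight_on_input:
        return a1
    topk = topk_ids.size(1)
    # Mirrors the TODO above: only the topk=1 case is supported for now.
    assert topk == 1, (
        "apply_router_weight_on_input is only implemented for topk=1")
    return a1 * topk_weights.to(a1.dtype)
```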
use_batched_experts = (
    isinstance(prepare_finalize, BatchedPrepareAndFinalize) or
    (has_pplx and isinstance(prepare_finalize, PplxPrepareAndFinalize))
    or (has_deepep
        and isinstance(prepare_finalize, DeepEPLLPrepareAndFinalize)))
We should add a method to the prepare and finalize class that returns the activation format, e.g. expert-batched vs. non-batched. Then we won't need all the isinstance checks. I'm fine with doing this in another PR though.
fixed it 👍
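One possible shape for that hook - the enum and method names below are illustrative assumptions, not the API that was merged:

```python
from enum import Enum


class ActivationFormat(Enum):
    STANDARD = "standard"        # (num_tokens, hidden)
    BATCHED_EXPERTS = "batched"  # (num_experts, max_tokens, hidden)


class PrepareAndFinalizeBase:
    def activation_format(self) -> ActivationFormat:
        return ActivationFormat.STANDARD


class DeepEPLLPrepareAndFinalize(PrepareAndFinalizeBase):
    def activation_format(self) -> ActivationFormat:
        return ActivationFormat.BATCHED_EXPERTS


def uses_batched_experts(prepare_finalize: PrepareAndFinalizeBase) -> bool:
    # The caller no longer needs any isinstance checks.
    return (prepare_finalize.activation_format()
            is ActivationFormat.BATCHED_EXPERTS)
```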
if (self.moe_parallel_config.use_pplx_kernels
        or self.moe_parallel_config.use_deepep_ll_kernels):
This isn't applicable for the ht kernels?
No. The HT kernels are not batched.
if (self.use_pplx_kernels or self.use_deepep_ht_kernels
        or self.use_deepep_ll_kernels):
Maybe we should add a new umbrella property for all these types of kernels?
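For illustration, a sketch of such an umbrella property; the property name and the standalone config class below are assumptions, not existing vLLM code:

```python
class MoEParallelConfig:
    """Mirrors the flags referenced above (sketch only)."""

    def __init__(self,
                 use_pplx_kernels: bool = False,
                 use_deepep_ht_kernels: bool = False,
                 use_deepep_ll_kernels: bool = False):
        self.use_pplx_kernels = use_pplx_kernels
        self.use_deepep_ht_kernels = use_deepep_ht_kernels
        self.use_deepep_ll_kernels = use_deepep_ll_kernels

    @property
    def use_all2all_kernels(self) -> bool:
        # Umbrella flag covering every all-to-all dispatch/combine backend.
        return (self.use_pplx_kernels or self.use_deepep_ht_kernels
                or self.use_deepep_ll_kernels)
```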
@@ -1305,12 +1408,17 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         assert self.quant_method is not None
-        if self.moe_parallel_config.use_pplx_kernels:
+        if (self.moe_parallel_config.use_pplx_kernels
+                or self.moe_parallel_config.use_deepep_ll_kernels):
No ht kernels here either?
No. The HT kernels aren't batched.
do_naive_dispatch_combine: bool = (
    self.dp_size > 1
    and not self.moe_parallel_config.use_deepep_ht_kernels)
if do_naive_dispatch_combine:
Another future TODO, put the naive dispatch/combine into a NaivePrepareAndFinalize class. (I was planning on doing this but thought I'd mention it for posterity).
@@ -459,8 +462,10 @@ def __init__(self, quant_config: Fp8Config):
             logger.warning_once(
                 "DeepGemm not supported on the current platform.")

+        self.topk_indices_dtype = None
Is this initialized in the base class and set when select_gemm_impl is called? Maybe we can do away with this?
There is actually an edge case - with DP=1, TP=2, and enable_expert_parallel, select_gemm_impl isn't called at all. I ran into this when testing.
Looks good to me after the latest commit.
Force-pushed from aaa6ee3 to b03fd2c.
Integrate DeepEP dispatch-combine kernels

Correctness:

Models:
- deepseek-ai/DeepSeek-V2-Lite
- RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8
- Qwen/Qwen3-30B-A3B-FP8
ALL2ALL Backend: deepep_high_throughput
Cases: DP=2 TP=1

Models:
- deepseek-ai/DeepSeek-V2-Lite
- RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8
ALL2ALL Backend: deepep_low_latency
Cases: DP=2 TP=1
Note: the DeepEP low-latency kernels are compiled only for a fixed set of hidden sizes, and DeepSeek-V2-Lite's hidden size is not among them; I had to patch DeepEP to support that hidden size in order to run this test.

Models:
- deepseek-ai/DeepSeek-V2-Lite
ALL2ALL Backend: pplx
Cases: DP=2 TP=1