Skip to content

Commit 466aaae

Browse files
SunflowerAries and mgoin
authored and committed
[Bugfix] Fix deepseekv3 gate bias error (vllm-project#12002)
Signed-off-by: mgoin <[email protected]> Co-authored-by: mgoin <[email protected]>
1 parent cc98e18 commit 466aaae

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -497,7 +497,10 @@ def grouped_topk(hidden_states: torch.Tensor,
497497
raise ValueError(f"Unsupported scoring function: {scoring_func}")
498498

499499
if e_score_correction_bias is not None:
500-
scores.add_(e_score_correction_bias.unsqueeze(0))
500+
# Store original scores before applying correction bias. We use biased
501+
# scores for expert selection but original scores for routing weights
502+
original_scores = scores
503+
scores = scores + e_score_correction_bias.unsqueeze(0)
501504

502505
num_token = scores.shape[0]
503506
group_scores = scores.view(num_token, num_expert_group,
@@ -510,10 +513,16 @@ def grouped_topk(hidden_states: torch.Tensor,
510513
num_token, num_expert_group,
511514
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
512515
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
513-
topk_weights, topk_ids = torch.topk(tmp_scores,
514-
k=topk,
515-
dim=-1,
516-
sorted=False)
516+
517+
if e_score_correction_bias is not None:
518+
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
519+
# Use original unbiased scores for the routing weights
520+
topk_weights = original_scores.gather(1, topk_ids)
521+
else:
522+
topk_weights, topk_ids = torch.topk(tmp_scores,
523+
k=topk,
524+
dim=-1,
525+
sorted=False)
517526

518527
if renormalize:
519528
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

0 commit comments

Comments
 (0)