@@ -497,7 +497,10 @@ def grouped_topk(hidden_states: torch.Tensor,
497
497
raise ValueError (f"Unsupported scoring function: { scoring_func } " )
498
498
499
499
if e_score_correction_bias is not None :
500
- scores .add_ (e_score_correction_bias .unsqueeze (0 ))
500
+ # Store original scores before applying correction bias. We use biased
501
+ # scores for expert selection but original scores for routing weights
502
+ original_scores = scores
503
+ scores = scores + e_score_correction_bias .unsqueeze (0 )
501
504
502
505
num_token = scores .shape [0 ]
503
506
group_scores = scores .view (num_token , num_expert_group ,
@@ -510,10 +513,16 @@ def grouped_topk(hidden_states: torch.Tensor,
510
513
num_token , num_expert_group ,
511
514
scores .shape [- 1 ] // num_expert_group ).reshape (num_token , - 1 ) # [n, e]
512
515
tmp_scores = scores .masked_fill (~ score_mask .bool (), 0.0 ) # [n, e]
513
- topk_weights , topk_ids = torch .topk (tmp_scores ,
514
- k = topk ,
515
- dim = - 1 ,
516
- sorted = False )
516
+
517
+ if e_score_correction_bias is not None :
518
+ topk_ids = torch .topk (tmp_scores , k = topk , dim = - 1 , sorted = False )[1 ]
519
+ # Use original unbiased scores for the routing weights
520
+ topk_weights = original_scores .gather (1 , topk_ids )
521
+ else :
522
+ topk_weights , topk_ids = torch .topk (tmp_scores ,
523
+ k = topk ,
524
+ dim = - 1 ,
525
+ sorted = False )
517
526
518
527
if renormalize :
519
528
topk_weights = topk_weights / topk_weights .sum (dim = - 1 , keepdim = True )
0 commit comments