Skip to content

Commit e7d8492

Browse files
authored
Merge pull request #3 from DrownFish19/dev_20250505_add_qwen3
Fix EP for Qwen3Moe
2 parents 8602183 + f060f79 commit e7d8492

File tree

2 files changed: +12 additions, -9 deletions

paddlenlp/transformers/moe_gate.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor:
226226
chosen_expert = topk_idx.reshape([-1])
227227
# Shape: [seq_len * k, num_experts].
228228
token_priority = F.one_hot(chosen_expert, self.num_experts).cast(paddle.int32)
229-
token_priority = paddle.logical_and(token_priority > 0, token_priority.cumsum(axis=0) < capacity)
229+
token_priority = paddle.logical_and(token_priority > 0, token_priority.cumsum(axis=0) <= capacity)
230230
# Shape: [seq_len, num_experts].
231231
token_priority = token_priority.reshape([-1, k, self.num_experts]).sum(axis=1)
232232

@@ -532,12 +532,14 @@ def topkgating(
532532
token_priority = self._priority(top_idx, capacity)
533533

534534
# normalize gates
535+
# gates_masked is equal to top_gate.
535536
gates_masked = gates * mask
536-
if self.training:
537-
gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)
538-
denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
539-
if self.norm_topk_prob:
540-
gates_masked = gates_masked / denom_s
537+
# if self.training:
538+
gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True)
539+
denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps)
540+
if self.norm_topk_prob:
541+
gates_masked = gates_masked / denom_s
542+
gates_masked *= self.routed_scaling_factor
541543

542544
return (
543545
capacity,

paddlenlp/transformers/qwen3_moe/modeling.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(self, config: Qwen3MoeConfig):
8484
config.hidden_size,
8585
top_k=config.num_experts_per_tok,
8686
drop_tokens=False,
87+
norm_topk_prob=config.norm_topk_prob,
8788
)
8889

8990
super().__init__(
@@ -148,7 +149,7 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
148149
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
149150

150151
current_state = hidden_states[idx, None].reshape([-1, hidden_dim])
151-
current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x]
152+
current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
152153
final_hidden_states.index_add_(
153154
index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
154155
)
@@ -165,7 +166,7 @@ def __init__(self, config: Qwen3MoeConfig, layerwise_recompute: bool = False):
165166
self.self_attn = Qwen3MoeAttention(config, layerwise_recompute)
166167

167168
if config.num_experts > 0:
168-
self.mlp = Qwen3MoeSparseMoeBlock(config)
169+
self.mlp = ExpertParallelQwen3MoeSparseMoeBlock(config)
169170
else:
170171
# num_experts == 0 or this layer is not sparse layer
171172
self.mlp = Qwen3MoeMLP(config)
@@ -828,7 +829,7 @@ def prepare_inputs_for_generation(
828829
attention_mask=None,
829830
inputs_embeds=None,
830831
output_router_logits=False,
831-
**kwargs
832+
**kwargs,
832833
):
833834
batch_size, seq_length = input_ids.shape
834835
position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length)))

0 commit comments

Comments (0)