@@ -84,6 +84,7 @@ def __init__(self, config: Qwen3MoeConfig):
             config.hidden_size,
             top_k=config.num_experts_per_tok,
             drop_tokens=False,
+            norm_topk_prob=config.norm_topk_prob,
         )
 
         super().__init__(
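For context, `norm_topk_prob` controls whether the top-k routing probabilities picked by the gate are re-normalized so that each token's k expert weights sum to 1 before the expert outputs are mixed. A minimal sketch of that gating step, assuming a softmax-then-top-k router (the function and variable names below are illustrative, not the actual PaddleNLP gate API):

    import paddle
    import paddle.nn.functional as F

    def route_tokens(router_logits, top_k, norm_topk_prob):
        # router_logits: [num_tokens, num_experts]
        probs = F.softmax(router_logits, axis=-1)
        routing_weights, selected_experts = paddle.topk(probs, k=top_k, axis=-1)
        if norm_topk_prob:
            # re-normalize so the k selected weights of each token sum to 1
            routing_weights = routing_weights / routing_weights.sum(axis=-1, keepdim=True)
        return routing_weights, selected_experts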
@@ -148,7 +149,7 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
 
             current_state = hidden_states[idx, None].reshape([-1, hidden_dim])
-            current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x]
+            current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
             final_hidden_states.index_add_(
                 index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
             )
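The added `.unsqueeze(-1)` matters because `expert_layer(current_state)` has shape `[num_selected_tokens, hidden_dim]` while `routing_weights[idx, top_x]` is a 1-D vector of per-token weights, so a trailing axis is needed for the elementwise multiply to broadcast each weight across the hidden dimension. A shape-only sketch (the tensor sizes are illustrative):

    import paddle

    num_selected_tokens, hidden_dim = 4, 8
    expert_out = paddle.randn([num_selected_tokens, hidden_dim])  # [4, 8]
    weights = paddle.rand([num_selected_tokens])                  # [4]

    # [4, 8] * [4, 1] broadcasts each token's weight across hidden_dim
    scaled = expert_out * weights.unsqueeze(-1)
    print(scaled.shape)  # [4, 8]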
@@ -165,7 +166,7 @@ def __init__(self, config: Qwen3MoeConfig, layerwise_recompute: bool = False):
         self.self_attn = Qwen3MoeAttention(config, layerwise_recompute)
 
         if config.num_experts > 0:
-            self.mlp = Qwen3MoeSparseMoeBlock(config)
+            self.mlp = ExpertParallelQwen3MoeSparseMoeBlock(config)
         else:
             # num_experts == 0 or this layer is not sparse layer
             self.mlp = Qwen3MoeMLP(config)
@@ -828,7 +829,7 @@ def prepare_inputs_for_generation(
         attention_mask=None,
         inputs_embeds=None,
         output_router_logits=False,
-        **kwargs
+        **kwargs,
     ):
         batch_size, seq_length = input_ids.shape
         position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length)))