@@ -160,14 +160,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
160
160
hidden_states = hidden_states ,
161
161
router_logits = router_logits ) * self .routed_scaling_factor
162
162
else :
163
- # This is a special case to avoid FP16 overflow
163
+ # Fix FP16 overflow
164
+ # See DeepseekV2DecoderLayer for more details.
164
165
final_hidden_states = self .experts (hidden_states = hidden_states ,
165
166
router_logits = router_logits )
166
167
if shared_output is not None :
167
168
if hidden_states .dtype != torch .float16 :
168
169
final_hidden_states = final_hidden_states + shared_output
169
170
else :
170
- # This is a special case to avoid FP16 overflow
171
+ # Fix FP16 overflow
172
+ # See DeepseekV2DecoderLayer for more details.
171
173
final_hidden_states = final_hidden_states + shared_output \
172
174
* (1. / self .routed_scaling_factor )
173
175
if self .tp_size > 1 :
@@ -499,6 +501,7 @@ def __init__(
499
501
# DecoderLayers are created with `make_layers` which passes the prefix
500
502
# with the layer's index.
501
503
layer_idx = int (prefix .split (sep = '.' )[- 1 ])
504
+ self .layer_idx = layer_idx
502
505
if model_config .use_mla :
503
506
attn_cls = DeepseekV2MLAAttention
504
507
else :
@@ -561,19 +564,30 @@ def forward(
561
564
hidden_states = hidden_states ,
562
565
)
563
566
564
- # Fully Connected
565
- if isinstance ( self . mlp , DeepseekV2MoE ) and \
566
- hidden_states . dtype == torch . float16 :
567
- # This is a special case to avoid FP16 overflow
567
+ if hidden_states . dtype == torch . float16 :
568
+ # Fix FP16 overflow
569
+ # We scale both hidden_states and residual before
570
+ # rmsnorm, and rmsnorm result would not affect by scale.
568
571
hidden_states *= 1. / self .routed_scaling_factor
572
+ if self .layer_idx == 0 :
573
+ # The residual is shared by all layers, we only scale it on
574
+ # first layer.
575
+ residual *= 1. / self .routed_scaling_factor
576
+
577
+ # Fully Connected
569
578
hidden_states , residual = self .post_attention_layernorm (
570
579
hidden_states , residual )
571
580
hidden_states = self .mlp (hidden_states )
572
- if isinstance (self .mlp , DeepseekV2MLP ) and \
573
- hidden_states .dtype == torch .float16 :
574
- # This is a special case to avoid FP16 overflow
581
+
582
+ if isinstance (self .mlp ,
583
+ DeepseekV2MLP ) and hidden_states .dtype == torch .float16 :
584
+ # Fix FP16 overflow
585
+ # Scaling the DeepseekV2MLP output, it is the input of
586
+ # input_layernorm of next decoder layer.
587
+ # The scaling of DeepseekV2MOE output would be done in the forward
588
+ # of DeepseekV2MOE
575
589
hidden_states *= 1. / self .routed_scaling_factor
576
- residual *= 1. / self . routed_scaling_factor
590
+
577
591
return hidden_states , residual
578
592
579
593
0 commit comments