Commit db10422

jinzhen-lin and mgoin authored
[Bugfix] fix deepseek fp16 scale bug (#14809)
Signed-off-by: Jinzhen Lin <[email protected]>
Co-authored-by: mgoin <[email protected]>
1 parent e1a2c69 commit db10422

File tree

1 file changed: +24 -10 lines changed

vllm/model_executor/models/deepseek_v2.py

Lines changed: 24 additions & 10 deletions
@@ -160,14 +160,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                 hidden_states=hidden_states,
                 router_logits=router_logits) * self.routed_scaling_factor
         else:
-            # This is a special case to avoid FP16 overflow
+            # Fix FP16 overflow
+            # See DeepseekV2DecoderLayer for more details.
             final_hidden_states = self.experts(hidden_states=hidden_states,
                                                router_logits=router_logits)
         if shared_output is not None:
             if hidden_states.dtype != torch.float16:
                 final_hidden_states = final_hidden_states + shared_output
             else:
-                # This is a special case to avoid FP16 overflow
+                # Fix FP16 overflow
+                # See DeepseekV2DecoderLayer for more details.
                 final_hidden_states = final_hidden_states + shared_output \
                     * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
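
For context (not part of the commit): the FP16 branch above leaves the routed-expert output unscaled because multiplying it by routed_scaling_factor can push activations past float16's maximum of about 65504. A minimal sketch of the failure mode and the workaround, with an illustrative scaling factor of 16.0 (an assumption, not taken from any model config):

import torch

# Illustrative value only; the real factor comes from the model config.
routed_scaling_factor = 16.0

expert_out = torch.full((4,), 5000.0, dtype=torch.float16)
shared_out = torch.full((4,), 5000.0, dtype=torch.float16)

# Naive scaling of the routed output overflows float16 (5000 * 16 > 65504).
naive = expert_out * routed_scaling_factor + shared_out
print(naive)  # tensor of inf values

# The FP16 branch instead divides the shared output, keeping the sum a
# factor of routed_scaling_factor smaller and finite.
safe = expert_out + shared_out * (1.0 / routed_scaling_factor)
print(safe)  # finite values

Per the diff comments, the skipped factor is compensated in DeepseekV2DecoderLayer, which keeps the whole residual stream downscaled by the same amount in FP16.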
@@ -499,6 +501,7 @@ def __init__(
         # DecoderLayers are created with `make_layers` which passes the prefix
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
+        self.layer_idx = layer_idx
         if model_config.use_mla:
             attn_cls = DeepseekV2MLAAttention
         else:
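
For context (not part of the commit): storing self.layer_idx lets the forward pass below recognize the first decoder layer, which is the only place the shared residual tensor should be downscaled. A tiny sketch of the index parsing, with a hypothetical prefix string of the "<parent>.<layer index>" form that make_layers produces:

# Hypothetical prefix, assumed for illustration only.
prefix = "model.layers.0"
layer_idx = int(prefix.split(sep='.')[-1])
print(layer_idx)  # 0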
@@ -561,19 +564,30 @@ def forward(
             hidden_states=hidden_states,
         )

-        # Fully Connected
-        if isinstance(self.mlp, DeepseekV2MoE) and \
-            hidden_states.dtype == torch.float16:
-            # This is a special case to avoid FP16 overflow
+        if hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # We scale both hidden_states and residual before
+            # rmsnorm, and rmsnorm result would not affect by scale.
             hidden_states *= 1. / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                # The residual is shared by all layers, we only scale it on
+                # first layer.
+                residual *= 1. / self.routed_scaling_factor
+
+        # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        if isinstance(self.mlp, DeepseekV2MLP) and \
-            hidden_states.dtype == torch.float16:
-            # This is a special case to avoid FP16 overflow
+
+        if isinstance(self.mlp,
+                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # Scaling the DeepseekV2MLP output, it is the input of
+            # input_layernorm of next decoder layer.
+            # The scaling of DeepseekV2MOE output would be done in the forward
+            # of DeepseekV2MOE
             hidden_states *= 1. / self.routed_scaling_factor
-        residual *= 1. / self.routed_scaling_factor
+
         return hidden_states, residual
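
For context (not part of the commit): the decoder-layer change relies on RMSNorm being invariant to a uniform scale of its input, so dividing both hidden_states and residual by routed_scaling_factor keeps FP16 activations in range without changing what the next layernorm feeds into attention or the MLP; the residual tensor is carried across layers, which is why it is downscaled only once, in layer 0. A minimal sketch of that invariance, using a plain RMSNorm without a learned weight (an illustrative stand-in, not vLLM's RMSNorm class):

import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm (no learned weight), for illustration only.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


x = torch.randn(2, 8)
scale = 1.0 / 16.0  # stand-in for 1. / routed_scaling_factor

# Scaling the input changes its magnitude but not the normalized output
# (up to eps-sized differences).
print(torch.allclose(rms_norm(x), rms_norm(x * scale),
                     rtol=1e-3, atol=1e-3))  # True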