Commit 2cb2441

scale before matmul (#5762)
1 parent e70e9d8 commit 2cb2441

File tree

1 file changed: +5 -4 lines

paddlenlp/transformers/llama/modeling.py

@@ -126,8 +126,7 @@ def scaled_dot_product_attention(
     key_states = paddle.transpose(key_states, [0, 2, 1, 3])
     value_states = paddle.transpose(value_states, [0, 2, 1, 3])
 
-    attn_weights = paddle.matmul(query_states, key_states.transpose([0, 1, 3, 2])) / math.sqrt(head_dim)
-
+    attn_weights = paddle.matmul(query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2]))
     if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]:
         raise ValueError(
             f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is"
@@ -746,8 +745,10 @@ def __init__(self, tensor_parallel_degree=1, tensor_parallel_output=False):
 
     def forward(self, prediction_scores, masked_lm_labels, ignore_index=-100):
         masked_lm_loss = self.loss_func(prediction_scores, masked_lm_labels.unsqueeze(2))
-        masked_lm_loss = masked_lm_loss[masked_lm_labels != ignore_index]
-        loss = paddle.mean(masked_lm_loss)
+        with paddle.amp.auto_cast(False):
+            masked_lm_loss = masked_lm_loss.astype("float32")
+            masked_lm_loss = masked_lm_loss[masked_lm_labels != ignore_index]
+            loss = paddle.mean(masked_lm_loss)
         return loss
 
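The loss change pulls the masking and reduction out of AMP autocast and casts the per-token losses to float32, so they are not selected and averaged in float16. A self-contained sketch of the same pattern; the wrapper function name and the toy shapes are illustrative, not taken from PaddleNLP:

import paddle

# reduction="none" keeps one loss value per token; Paddle's CrossEntropyLoss ignores label -100 by default.
loss_func = paddle.nn.CrossEntropyLoss(reduction="none")

def masked_lm_loss_fn(prediction_scores, masked_lm_labels, ignore_index=-100):
    # Per-token loss, shape [batch, seq_len, 1].
    masked_lm_loss = loss_func(prediction_scores, masked_lm_labels.unsqueeze(2))
    with paddle.amp.auto_cast(False):
        # Cast up and reduce in float32 even if the forward pass ran in float16/bfloat16.
        masked_lm_loss = masked_lm_loss.astype("float32")
        masked_lm_loss = masked_lm_loss[masked_lm_labels != ignore_index]
        loss = paddle.mean(masked_lm_loss)
    return loss

# Toy usage.
logits = paddle.randn([2, 8, 32000])                 # [batch, seq_len, vocab_size]
labels = paddle.randint(0, 32000, shape=[2, 8])
labels[:, :2] = -100                                 # ignored (e.g. prompt/padding) positions
print(masked_lm_loss_fn(logits, labels))

When training runs in plain float32, the auto_cast(False) block is effectively a no-op, so the guard only matters under mixed precision.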
