@@ -126,8 +126,7 @@ def scaled_dot_product_attention(
        key_states = paddle.transpose(key_states, [0, 2, 1, 3])
        value_states = paddle.transpose(value_states, [0, 2, 1, 3])

-        attn_weights = paddle.matmul(query_states, key_states.transpose([0, 1, 3, 2])) / math.sqrt(head_dim)
-
+        attn_weights = paddle.matmul(query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2]))
        if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]:
            raise ValueError(
                f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is"
@@ -746,8 +745,10 @@ def __init__(self, tensor_parallel_degree=1, tensor_parallel_output=False):

    def forward(self, prediction_scores, masked_lm_labels, ignore_index=-100):
        masked_lm_loss = self.loss_func(prediction_scores, masked_lm_labels.unsqueeze(2))
-        masked_lm_loss = masked_lm_loss[masked_lm_labels != ignore_index]
-        loss = paddle.mean(masked_lm_loss)
+        with paddle.amp.auto_cast(False):
+            masked_lm_loss = masked_lm_loss.astype("float32")
+            masked_lm_loss = masked_lm_loss[masked_lm_labels != ignore_index]
+            loss = paddle.mean(masked_lm_loss)
        return loss
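
The second hunk keeps the loss reduction out of AMP auto-casting and promotes the per-token losses to float32 before masking and averaging, so the mean is not accumulated in float16. A minimal standalone sketch of the same pattern (the tensor values and shapes are illustrative assumptions, not the PR's criterion):

import paddle

# Per-token losses as they might come out of a float16 forward pass; -100 marks ignored tokens.
masked_lm_labels = paddle.to_tensor([[5, 7, -100, 9]], dtype="int64")
masked_lm_loss = paddle.to_tensor([[[2.1], [0.7], [0.0], [1.3]]], dtype="float16")

with paddle.amp.auto_cast(False):
    masked_lm_loss = masked_lm_loss.astype("float32")           # reduce in full precision
    masked_lm_loss = masked_lm_loss[masked_lm_labels != -100]   # drop ignored positions
    loss = paddle.mean(masked_lm_loss)

print(loss.dtype)  # paddle.float32
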