File tree 1 file changed +5
-1
lines changed
src/transformers/models/bert
1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -1266,6 +1266,7 @@ def __init__(self, config):
1266
1266
1267
1267
self .bert = BertModel (config , add_pooling_layer = False )
1268
1268
self .cls = BertOnlyMLMHead (config )
1269
+ self .ort = config .ort
1269
1270
1270
1271
self .init_weights ()
1271
1272
@@ -1326,7 +1327,10 @@ def forward(
1326
1327
masked_lm_loss = None
1327
1328
if labels is not None :
1328
1329
loss_fct = CrossEntropyLoss () # -100 index = padding token
1329
- masked_lm_loss = loss_fct (prediction_scores .view (- 1 , self .config .vocab_size ), labels .view (- 1 ))
1330
+ if self .ort :
1331
+ masked_lm_loss = loss_fct (prediction_scores .view (- 1 , self .config .vocab_size ).to (torch .float32 ), labels .view (- 1 ))
1332
+ else :
1333
+ masked_lm_loss = loss_fct (prediction_scores .view (- 1 , self .config .vocab_size ), labels .view (- 1 ))
1330
1334
1331
1335
if not return_dict :
1332
1336
output = (prediction_scores ,) + outputs [2 :]
You can’t perform that action at this time.
0 commit comments