Skip to content

Commit 6b9500b

Browse files
authored
Merge pull request #11 from microsoft/jingywa/hfbert-changes
Bert type cast fix
2 parents: ae1411f + 25e7be2; commit: 6b9500b

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/transformers/models/bert/modeling_bert.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1266,6 +1266,7 @@ def __init__(self, config):
1266 1266

1267 1267
self.bert = BertModel(config, add_pooling_layer=False)
1268 1268
self.cls = BertOnlyMLMHead(config)
1269+
self.ort = config.ort
1269 1270

1270 1271
self.init_weights()
1271 1272

@@ -1326,7 +1327,10 @@ def forward(
1326 1327
masked_lm_loss = None
1327 1328
if labels is not None:
1328 1329
loss_fct = CrossEntropyLoss() # -100 index = padding token
1329-
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1330+
if self.ort:
1331+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size).to(torch.float32), labels.view(-1))
1332+
else:
1333+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1330 1334

1331 1335
if not return_dict:
1332 1336
output = (prediction_scores,) + outputs[2:]

0 commit comments

Comments
 (0)