Skip to content

Commit efc9019

Browse files
authored
Merge pull request #14 from microsoft/raviskolli/ort
Remove model specific changes for BERT and DistilBERT
2 parents 0b2532a + 239767d commit efc9019

File tree

2 files changed: 2 additions, 10 deletions

src/transformers/models/bert/modeling_bert.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,7 +1266,6 @@ def __init__(self, config):
12661266

12671267
self.bert = BertModel(config, add_pooling_layer=False)
12681268
self.cls = BertOnlyMLMHead(config)
1269-
self.ort = config.ort
12701269

12711270
self.init_weights()
12721271

@@ -1327,10 +1326,7 @@ def forward(
13271326
masked_lm_loss = None
13281327
if labels is not None:
13291328
loss_fct = CrossEntropyLoss() # -100 index = padding token
1330-
if self.ort:
1331-
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size).to(torch.float32), labels.view(-1))
1332-
else:
1333-
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1329+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
13341330

13351331
if not return_dict:
13361332
output = (prediction_scores,) + outputs[2:]

src/transformers/models/distilbert/modeling_distilbert.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,6 @@ def __init__(self, config):
500500
self.vocab_transform = nn.Linear(config.dim, config.dim)
501501
self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
502502
self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
503-
self.ort = config.ort
504503

505504
self.init_weights()
506505

@@ -555,10 +554,7 @@ def forward(
555554

556555
mlm_loss = None
557556
if labels is not None:
558-
if self.ort:
559-
mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)).to(torch.float32), labels.view(-1))
560-
else:
561-
mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
557+
mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
562558

563559
if not return_dict:
564560
output = (prediction_logits,) + dlbrt_output[1:]

0 commit comments

Comments (0)