File tree 2 files changed +2
-10
lines changed 2 files changed +2
-10
lines changed Original file line number Diff line number Diff line change @@ -1266,7 +1266,6 @@ def __init__(self, config):
1266
1266
1267
1267
self .bert = BertModel (config , add_pooling_layer = False )
1268
1268
self .cls = BertOnlyMLMHead (config )
1269
- self .ort = config .ort
1270
1269
1271
1270
self .init_weights ()
1272
1271
@@ -1327,10 +1326,7 @@ def forward(
1327
1326
masked_lm_loss = None
1328
1327
if labels is not None :
1329
1328
loss_fct = CrossEntropyLoss () # -100 index = padding token
1330
- if self .ort :
1331
- masked_lm_loss = loss_fct (prediction_scores .view (- 1 , self .config .vocab_size ).to (torch .float32 ), labels .view (- 1 ))
1332
- else :
1333
- masked_lm_loss = loss_fct (prediction_scores .view (- 1 , self .config .vocab_size ), labels .view (- 1 ))
1329
+ masked_lm_loss = loss_fct (prediction_scores .view (- 1 , self .config .vocab_size ), labels .view (- 1 ))
1334
1330
1335
1331
if not return_dict :
1336
1332
output = (prediction_scores ,) + outputs [2 :]
Original file line number Diff line number Diff line change @@ -500,7 +500,6 @@ def __init__(self, config):
500
500
self .vocab_transform = nn .Linear (config .dim , config .dim )
501
501
self .vocab_layer_norm = nn .LayerNorm (config .dim , eps = 1e-12 )
502
502
self .vocab_projector = nn .Linear (config .dim , config .vocab_size )
503
- self .ort = config .ort
504
503
505
504
self .init_weights ()
506
505
@@ -555,10 +554,7 @@ def forward(
555
554
556
555
mlm_loss = None
557
556
if labels is not None :
558
- if self .ort :
559
- mlm_loss = self .mlm_loss_fct (prediction_logits .view (- 1 , prediction_logits .size (- 1 )).to (torch .float32 ), labels .view (- 1 ))
560
- else :
561
- mlm_loss = self .mlm_loss_fct (prediction_logits .view (- 1 , prediction_logits .size (- 1 )), labels .view (- 1 ))
557
+ mlm_loss = self .mlm_loss_fct (prediction_logits .view (- 1 , prediction_logits .size (- 1 )), labels .view (- 1 ))
562
558
563
559
if not return_dict :
564
560
output = (prediction_logits ,) + dlbrt_output [1 :]
You can’t perform that action at this time.
0 commit comments