@@ -388,10 +388,10 @@ def init_weights(module):
             if isinstance(module, (nn.Linear, nn.Embedding)):
                 # Slightly different from the TF version which uses truncated_normal for initialization
                 # cf https://github.com/pytorch/pytorch/pull/5617
-                module.weight.data.normal_(config.initializer_range)
+                module.weight.data.normal_(mean=0.0, std=config.initializer_range)
             elif isinstance(module, BERTLayerNorm):
-                module.beta.data.normal_(config.initializer_range)
-                module.gamma.data.normal_(config.initializer_range)
+                module.beta.data.normal_(mean=0.0, std=config.initializer_range)
+                module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
             if isinstance(module, nn.Linear):
                 module.bias.data.zero_()
         self.apply(init_weights)
@@ -438,10 +438,10 @@ def init_weights(module):
             if isinstance(module, (nn.Linear, nn.Embedding)):
                 # Slightly different from the TF version which uses truncated_normal for initialization
                 # cf https://github.com/pytorch/pytorch/pull/5617
-                module.weight.data.normal_(config.initializer_range)
+                module.weight.data.normal_(mean=0.0, std=config.initializer_range)
             elif isinstance(module, BERTLayerNorm):
-                module.beta.data.normal_(config.initializer_range)
-                module.gamma.data.normal_(config.initializer_range)
+                module.beta.data.normal_(mean=0.0, std=config.initializer_range)
+                module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
             if isinstance(module, nn.Linear):
                 module.bias.data.zero_()
         self.apply(init_weights)
@@ -459,7 +459,7 @@ def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None
             start_positions = start_positions.squeeze(-1)
             end_positions = end_positions.squeeze(-1)
             # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1) + 1
+            ignored_index = start_logits.size(1)
             start_positions.clamp_(0, ignored_index)
             end_positions.clamp_(0, ignored_index)
0 commit comments