Skip to content

Commit 2a97fe2

Browse files
committed
fixing weights initialization in the model and out of span clamping
1 parent 907d356 commit 2a97fe2

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

modeling.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -388,10 +388,10 @@ def init_weights(module):
388388
if isinstance(module, (nn.Linear, nn.Embedding)):
389389
# Slightly different from the TF version which uses truncated_normal for initialization
390390
# cf https://github.com/pytorch/pytorch/pull/5617
391-
module.weight.data.normal_(config.initializer_range)
391+
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
392392
elif isinstance(module, BERTLayerNorm):
393-
module.beta.data.normal_(config.initializer_range)
394-
module.gamma.data.normal_(config.initializer_range)
393+
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
394+
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
395395
if isinstance(module, nn.Linear):
396396
module.bias.data.zero_()
397397
self.apply(init_weights)
@@ -438,10 +438,10 @@ def init_weights(module):
438438
if isinstance(module, (nn.Linear, nn.Embedding)):
439439
# Slightly different from the TF version which uses truncated_normal for initialization
440440
# cf https://github.com/pytorch/pytorch/pull/5617
441-
module.weight.data.normal_(config.initializer_range)
441+
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
442442
elif isinstance(module, BERTLayerNorm):
443-
module.beta.data.normal_(config.initializer_range)
444-
module.gamma.data.normal_(config.initializer_range)
443+
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
444+
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
445445
if isinstance(module, nn.Linear):
446446
module.bias.data.zero_()
447447
self.apply(init_weights)
@@ -459,7 +459,7 @@ def forward(self, input_ids, token_type_ids, attention_mask, start_positions=Non
459459
start_positions = start_positions.squeeze(-1)
460460
end_positions = end_positions.squeeze(-1)
461461
# sometimes the start/end positions are outside our model inputs, we ignore these terms
462-
ignored_index = start_logits.size(1) + 1
462+
ignored_index = start_logits.size(1)
463463
start_positions.clamp_(0, ignored_index)
464464
end_positions.clamp_(0, ignored_index)
465465

0 commit comments

Comments
 (0)