Commit fe861e5
younesbelkada authored and sgugger committed
[GPT2] Add correct keys on _keys_to_ignore_on_load_unexpected on all child classes of GPT2PreTrainedModel (#24113)
* add correct keys on `_keys_to_ignore_on_load_unexpected`
* oops
1 parent b3e27a8 · commit fe861e5

File tree

1 file changed: +5 -1 lines

src/transformers/models/gpt2/modeling_gpt2.py

+5 -1
@@ -668,7 +668,8 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
     GPT2_START_DOCSTRING,
 )
 class GPT2Model(GPT2PreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
+    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1149,6 +1150,7 @@ def _reorder_cache(
     GPT2_START_DOCSTRING,
 )
 class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
     _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

     def __init__(self, config):
@@ -1377,6 +1379,7 @@ def _reorder_cache(
     GPT2_START_DOCSTRING,
 )
 class GPT2ForSequenceClassification(GPT2PreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
     _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]

     def __init__(self, config):
@@ -1600,6 +1603,7 @@ def forward(
     GPT2_START_DOCSTRING,
 )
 class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
     _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]

     def __init__(self, config):
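
For context: `attn.bias` (the causal mask) and `attn.masked_bias` are buffers of GPT-2's attention layers that older checkpoints serialized to disk, so loading such a checkpoint into a class that does not list them in `_keys_to_ignore_on_load_unexpected` triggers a spurious "unexpected keys" warning. Below is a minimal sketch of the mechanism these regex patterns feed into; the real filtering lives inside `PreTrainedModel.from_pretrained`, and `filter_unexpected_keys` is a hypothetical stand-in for it, not the library's API:

import re

# Patterns added by this commit: they match the per-layer attention
# buffers that old GPT-2 checkpoints saved (e.g. "h.0.attn.bias").
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]

def filter_unexpected_keys(unexpected_keys, patterns):
    # Drop any checkpoint key matching one of the ignore patterns,
    # mirroring the filtering `from_pretrained` applies before warning.
    return [k for k in unexpected_keys if not any(re.search(p, k) for p in patterns)]

# "unrelated.key" is an illustrative non-matching key: it survives the
# filter and would still be reported as unexpected.
print(filter_unexpected_keys(
    ["h.0.attn.bias", "h.11.attn.masked_bias", "unrelated.key"],
    _keys_to_ignore_on_load_unexpected,
))
# -> ['unrelated.key']

With the patterns present on every child class rather than only on `GPT2Model`, loading an old GPT-2 checkpoint into `GPT2DoubleHeadsModel`, `GPT2ForSequenceClassification`, or `GPT2ForQuestionAnswering` should no longer warn about these buffers.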
