
Commit 7336896

monologg authored and LysandreJik committed
Fix importing unofficial TF models with extra optimizer weights
1 parent d7dabfe · commit 7336896

4 files changed: +19 additions, -4 deletions

src/transformers/modeling_albert.py
Lines changed: 7 additions & 1 deletion

@@ -117,7 +117,13 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
         name = name.split("/")

         # Ignore the gradients applied by the LAMB/ADAM optimizers.
-        if "adam_m" in name or "adam_v" in name or "global_step" in name:
+        if (
+            "adam_m" in name
+            or "adam_v" in name
+            or "AdamWeightDecayOptimizer" in name
+            or "AdamWeightDecayOptimizer_1" in name
+            or "global_step" in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue

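For context, the extra variables this change skips are easy to see by listing a checkpoint's contents. The sketch below is illustrative and not part of the commit: the checkpoint path is a placeholder, and tf.train.list_variables is the standard TensorFlow API for enumerating (name, shape) pairs.

# Illustrative only (not part of this commit): inspect which variables an
# unofficial TF checkpoint actually contains. The path is a placeholder.
import tensorflow as tf

for name, shape in tf.train.list_variables("/path/to/albert/model.ckpt"):
    print(name, shape)

# Checkpoints trained with AdamWeightDecayOptimizer typically list slot
# variables such as ".../kernel/AdamWeightDecayOptimizer" and
# ".../kernel/AdamWeightDecayOptimizer_1" next to the model weights;
# these are what the widened skip condition above now catches.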

src/transformers/modeling_bert.py
Lines changed: 4 additions & 1 deletion

@@ -86,7 +86,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
         name = name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
         # which are not required for using pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue
         pointer = model
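As a quick sanity check of the new condition, here is a minimal, self-contained sketch of the any(...) filter; the variable names below are invented for illustration and do not come from a real checkpoint.

# Minimal sketch of the widened skip filter; the names are made-up examples
# of what an unofficial checkpoint might contain.
OPTIMIZER_SEGMENTS = ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]

example_names = [
    "bert/encoder/layer_0/attention/self/query/kernel",
    "bert/encoder/layer_0/attention/self/query/kernel/adam_m",
    "bert/encoder/layer_0/attention/self/query/kernel/AdamWeightDecayOptimizer",
    "global_step",
]

for full_name in example_names:
    name = full_name.split("/")
    if any(n in OPTIMIZER_SEGMENTS for n in name):
        print("skip:", full_name)  # optimizer state, not a model weight
    else:
        print("load:", full_name)  # actual model weight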

src/transformers/modeling_t5.py
Lines changed: 4 additions & 1 deletion

@@ -79,7 +79,10 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
         name = txt_name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
         # which are not required for using pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             tf_weights.pop(txt_name, None)
             continue
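The T5 loader differs in one detail: it keeps the checkpoint contents in a tf_weights dict and later reports any names that were never consumed, so skipped names are popped up front. A small sketch of that bookkeeping follows; the data is invented, and it describes the surrounding loader code rather than anything shown in this hunk.

# Sketch of the bookkeeping around tf_weights.pop (invented example data).
# Popping a skipped name prevents it from being reported later as a weight
# that was never copied into the PyTorch model.
tf_weights = {
    "encoder/block_000/layer_000/SelfAttention/q": "array...",
    "encoder/block_000/layer_000/SelfAttention/q/adam_m": "array...",
}

txt_name = "encoder/block_000/layer_000/SelfAttention/q/adam_m"
tf_weights.pop(txt_name, None)  # drop the skipped optimizer slot

print("unconsumed:", list(tf_weights.keys()))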

templates/adding_a_new_model/modeling_xxx.py
Lines changed: 4 additions & 1 deletion

@@ -76,7 +76,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
         name = name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
         # which are not required for using pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue
         pointer = model
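With these changes, converting an unofficial checkpoint should work end to end. Below is a hedged usage sketch for the BERT case: the file paths are placeholders, and the import path matches the src/transformers/modeling_bert.py layout this commit touches.

# Hedged end-to-end sketch (paths are placeholders): with the fix, loading a
# checkpoint that still carries AdamWeightDecayOptimizer slots skips them
# instead of failing to match them against the PyTorch model.
from transformers import BertConfig, BertForPreTraining
from transformers.modeling_bert import load_tf_weights_in_bert

config = BertConfig.from_json_file("/path/to/bert_config.json")
model = BertForPreTraining(config)
load_tf_weights_in_bert(model, config, "/path/to/model.ckpt")
model.save_pretrained("/path/to/pytorch_dump")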
