Skip to content

Commit 20d07b3

Browse files
committed
Excluding AdamWeightDecayOptimizer internal variables from restoring
1 parent 278fd28 commit 20d07b3

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

convert_tf_checkpoint_to_pytorch.py

100644100755
Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,17 @@ def convert():
6868
arrays.append(array)
6969

7070
for name, array in zip(names, arrays):
71-
name = name[5:] # skip "bert/"
71+
if not name.startswith("bert"):
72+
print("Skipping {}".format(name))
73+
continue
74+
else:
75+
name = name.replace("bert/", "") # skip "bert/"
7276
print("Loading {}".format(name))
7377
name = name.split('/')
74-
if name[0] in ['redictions', 'eq_relationship']:
75-
print("Skipping")
78+
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
79+
# which are not required for using pretrained model
80+
if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m":
81+
print("Skipping {}".format("/".join(name)))
7682
continue
7783
pointer = model
7884
for m_name in name:

0 commit comments

Comments
 (0)