Skip to content

Commit 3118bc0

Browse files
Rocketknight1iantbutler01
authored andcommitted
Proper build() methods for TF (huggingface#27794)
* Add a convenience method for building in your own name scope * Second attempt at auto layer building * Revert "Second attempt at auto layer building" This reverts commit e03a3aa. * Attempt poedator#3 * Revert "Attempt poedator#3" This reverts commit b9df7a0. * Add missing attributes that we're going to need later * Add some attributes we're going to need later * A fourth attempt! Feel the power flow through you! * Revert "A fourth attempt! Feel the power flow through you!" This reverts commit 6bf4aaf. * Add more values we'll need later * TF refactor that we'll need later * Revert "TF refactor that we'll need later" This reverts commit ca07202. * Revert "Revert "TF refactor that we'll need later"" This reverts commit 1beb0f3. * make fixup * Attempt five! * Revert "Attempt five!" This reverts commit 3302207. * Attempt six - this time don't add empty methods * Revert "Attempt six - this time don't add empty methods" This reverts commit 67d6012. * Attempt seven - better base model class detection! * Revert "Attempt seven - better base model class detection!" This reverts commit 5f14845. * Another attribute we'll need later * Try again with the missing attribute! * Revert "Try again with the missing attribute!" This reverts commit 760c6f3. * This is the attempt that will pierce the heavens! * Revert "This is the attempt that will pierce the heavens!" This reverts commit c868bb6. * Attempt seven - snag list is steadily decreasing * Revert "Attempt seven - snag list is steadily decreasing" This reverts commit 46fbd97. * Attempt eight - will an empty snag list do it? * Revert "Attempt eight - will an empty snag list do it?" This reverts commit 7c8a3c2. * Fixes to Hubert issues that cause problems later * Trying again with Conv1D/SeparableConv fixes * Revert "Trying again with Conv1D/SeparableConv fixes" This reverts commit 55092bc. * Apply the build shape fixes to Wav2Vec2 as well * One more attempt! * Revert "One more attempt!" This reverts commit 5ac3e4c. * Another attempt! * Revert "Another attempt!" This reverts commit ea16d89. * Let's see how many failures we get without the internal build method * Fix OpenAI * Fix MobileBERT * (Mostly) fix GroupVIT * Fix BLIP * One more BLIP fix * One more BLIP fix! * Fix Regnet * Finally fully fix GroupViT * Fix Data2Vec and add the new AdaptivePool * Fix Segformer * Fix Albert * Fix Deberta/DebertaV2 * Fix XLM * Actually fix XLM * Fix Flaubert * Fix lxmert * Fix Resnet * Fix ConvBERT * Fix ESM * Fix Convnext / ConvnextV2 * Fix SAM * Fix Efficientformer * Fix LayoutLMv3 * Fix speech_to_text * Fix mpnet and mobilevit * Fix Swin * Fix CTRL * Fix CVT * Fix DPR * Fix Wav2Vec2 * Fix T5 * Fix Hubert * Fix GPT2 * Fix Whisper * Fix DeiT * Fix the encoder-decoder / dual-encoder classes * make fix-copies * build in name scope * Fix summarization test * Fix tied weight names for BART + Blenderbot * Fix tied weight name building * Fix to TFESM weight building * Update TF SAM * Expand all the shapes out into Big Boy Shapes
1 parent 0fed337 commit 3118bc0

File tree

73 files changed

+11039
-503
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+11039
-503
lines changed

src/transformers/modeling_tf_utils.py

+25-18
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from huggingface_hub import Repository, list_repo_files
3636
from keras import backend as K
3737
from packaging.version import parse
38-
from tensorflow.python.util.keras_deps import get_call_context_function
3938

4039
from . import DataCollatorWithPadding, DefaultDataCollator
4140
from .activations_tf import get_tf_activation
@@ -1122,6 +1121,10 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
11221121
)
11231122
return dummies
11241123

1124+
def build_in_name_scope(self):
1125+
with tf.name_scope(self.name):
1126+
self.build(input_shape=None)
1127+
11251128
@property
11261129
def framework(self) -> str:
11271130
"""
@@ -1130,15 +1133,7 @@ def framework(self) -> str:
11301133
return "tf"
11311134

11321135
def build(self, input_shape=None):
1133-
call_context = get_call_context_function()
1134-
if self.built or call_context().in_call:
1135-
self.built = True
1136-
else:
1137-
self.built = True
1138-
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
1139-
# Setting it in build() allows users to override the shape when loading a non-pretrained model from config
1140-
self._set_save_spec(self.input_signature)
1141-
self(self.dummy_inputs, training=False)
1136+
pass # This is just here to make sure we don't call the superclass build()
11421137

11431138
def __init__(self, config, *inputs, **kwargs):
11441139
super().__init__(*inputs, **kwargs)
@@ -1869,7 +1864,7 @@ def set_input_embeddings(self, value):
18691864
main_layer.set_input_embeddings(value)
18701865
except AttributeError:
18711866
logger.info("Building the model")
1872-
self.build()
1867+
self.build_in_name_scope()
18731868
main_layer.set_input_embeddings(value)
18741869

18751870
def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1886,7 +1881,7 @@ def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
18861881
return lm_head.get_output_embeddings()
18871882
except AttributeError:
18881883
logger.info("Building the model")
1889-
self.build()
1884+
self.build_in_name_scope()
18901885

18911886
return lm_head().get_output_embeddings()
18921887

@@ -1906,7 +1901,7 @@ def set_output_embeddings(self, value):
19061901
lm_head.set_output_embeddings(value)
19071902
except AttributeError:
19081903
logger.info("Building the model")
1909-
self.build()
1904+
self.build_in_name_scope()
19101905
lm_head.set_output_embeddings(value)
19111906

19121907
def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1944,7 +1939,7 @@ def get_bias(self) -> Union[None, Dict[str, tf.Variable]]:
19441939
try:
19451940
return lm_head.get_bias()
19461941
except AttributeError:
1947-
self.build()
1942+
self.build_in_name_scope()
19481943

19491944
return lm_head.get_bias()
19501945
return None
@@ -1962,7 +1957,7 @@ def set_bias(self, value):
19621957
try:
19631958
lm_head.set_bias(value)
19641959
except AttributeError:
1965-
self.build()
1960+
self.build_in_name_scope()
19661961
lm_head.set_bias(value)
19671962

19681963
def get_lm_head(self) -> tf.keras.layers.Layer:
@@ -2049,7 +2044,7 @@ def _get_word_embedding_weight(model, embedding_layer):
20492044
# The reason why the attributes don't exist might be
20502045
# because the model is not built, so retry getting
20512046
# the argument after building the model
2052-
model.build()
2047+
model.build_in_name_scope()
20532048

20542049
embeds = getattr(embedding_layer, "weight", None)
20552050
if embeds is not None:
@@ -2914,9 +2909,9 @@ def from_pretrained(
29142909
# we might need to extend the variable scope for composite models
29152910
if load_weight_prefix is not None:
29162911
with tf.compat.v1.variable_scope(load_weight_prefix):
2917-
model.build() # build the network with dummy inputs
2912+
model.build_in_name_scope() # build the network with dummy inputs
29182913
else:
2919-
model.build() # build the network with dummy inputs
2914+
model.build_in_name_scope() # build the network with dummy inputs
29202915

29212916
if safetensors_from_pt:
29222917
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
@@ -3215,6 +3210,9 @@ def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
32153210
self.initializer_range = initializer_range
32163211

32173212
def build(self, input_shape):
3213+
if self.built:
3214+
return
3215+
self.built = True
32183216
self.weight = self.add_weight(
32193217
"weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
32203218
)
@@ -3398,6 +3396,7 @@ def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **
33983396
self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
33993397
if self.has_last_dropout:
34003398
self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
3399+
self.hidden_size = config.hidden_size
34013400

34023401
def call(self, inputs, cls_index=None, training=False):
34033402
if not isinstance(inputs, (dict, tuple, list)):
@@ -3450,6 +3449,14 @@ def call(self, inputs, cls_index=None, training=False):
34503449

34513450
return output
34523451

3452+
def build(self, input_shape):
3453+
if self.built:
3454+
return
3455+
self.built = True
3456+
if getattr(self, "summary", None) is not None:
3457+
with tf.name_scope("summary"):
3458+
self.summary.build(self.hidden_size)
3459+
34533460

34543461
def get_initializer(initializer_range: float = 0.02) -> tf.keras.initializers.TruncatedNormal:
34553462
"""

0 commit comments

Comments
 (0)