
Commit ca07202

TF refactor that we'll need later
1 parent 02a1f4c commit ca07202

File tree

1 file changed (+7, -6 lines)


src/transformers/models/t5/modeling_tf_t5.py (+7, -6)
@@ -75,16 +75,17 @@
 
 
 class TFT5LayerNorm(tf.keras.layers.Layer):
-    def __init__(self, epsilon=1e-6, **kwargs):
+    def __init__(self, hidden_size, epsilon=1e-6, **kwargs):
         """
         Construct a layernorm module in the T5 style No bias and no subtraction of mean.
         """
         super().__init__(**kwargs)
         self.variance_epsilon = epsilon
+        self.hidden_size = hidden_size
 
     def build(self, input_shape):
         """Build shared word embedding layer"""
-        self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
+        self.weight = self.add_weight("weight", shape=(self.hidden_size,), initializer="ones")
         super().build(input_shape)
 
     def call(self, hidden_states):
@@ -157,7 +158,7 @@ def __init__(self, config, **kwargs):
         else:
             self.DenseReluDense = TFT5DenseActDense(config, name="DenseReluDense")
 
-        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(self, hidden_states, training=False):
@@ -439,7 +440,7 @@ def __init__(self, config, has_relative_attention_bias=False, **kwargs):
             has_relative_attention_bias=has_relative_attention_bias,
             name="SelfAttention",
         )
-        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(
@@ -477,7 +478,7 @@ def __init__(self, config, **kwargs):
             has_relative_attention_bias=False,
             name="EncDecAttention",
         )
-        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(
@@ -640,7 +641,7 @@ def __init__(self, config, embed_tokens=None, **kwargs):
             TFT5Block(config, has_relative_attention_bias=bool(i == 0), name=f"block_._{i}")
             for i in range(config.num_layers)
         ]
-        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
+        self.final_layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="final_layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def _prune_heads(self, heads_to_prune):
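
For context, a minimal sketch (not part of the commit) of the refactored layer once the diff is applied: the weight shape now comes from the hidden_size passed at construction rather than from input_shape, so build() no longer has to inspect its input. The call() body is not shown in this diff; the RMS-style normalization below is the usual T5 layer norm and is included only for illustration, and the hidden_size=512 usage at the end is a hypothetical stand-in for config.d_model.

import tensorflow as tf


class TFT5LayerNorm(tf.keras.layers.Layer):
    def __init__(self, hidden_size, epsilon=1e-6, **kwargs):
        """T5-style layer norm: no bias and no subtraction of the mean."""
        super().__init__(**kwargs)
        self.variance_epsilon = epsilon
        self.hidden_size = hidden_size

    def build(self, input_shape):
        # After this commit the weight shape is taken from hidden_size,
        # not from input_shape, so the layer can be built up front.
        self.weight = self.add_weight(name="weight", shape=(self.hidden_size,), initializer="ones")
        super().build(input_shape)

    def call(self, hidden_states):
        # Illustrative RMS normalization; call() is unchanged by this diff.
        variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True)
        hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


# Hypothetical usage: hidden_size would normally be config.d_model.
layer_norm = TFT5LayerNorm(hidden_size=512, name="layer_norm")
out = layer_norm(tf.ones((2, 8, 512)))  # (batch, seq_len, d_model)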
