@@ -75,16 +75,17 @@
 
 
 class TFT5LayerNorm(tf.keras.layers.Layer):
-    def __init__(self, epsilon=1e-6, **kwargs):
+    def __init__(self, hidden_size, epsilon=1e-6, **kwargs):
         """
         Construct a layernorm module in the T5 style No bias and no subtraction of mean.
         """
         super().__init__(**kwargs)
         self.variance_epsilon = epsilon
+        self.hidden_size = hidden_size
 
     def build(self, input_shape):
         """Build shared word embedding layer"""
-        self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
+        self.weight = self.add_weight("weight", shape=(self.hidden_size,), initializer="ones")
         super().build(input_shape)
 
     def call(self, hidden_states):
@@ -157,7 +158,7 @@ def __init__(self, config, **kwargs):
         else:
             self.DenseReluDense = TFT5DenseActDense(config, name="DenseReluDense")
 
-        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(self, hidden_states, training=False):
@@ -439,7 +440,7 @@ def __init__(self, config, has_relative_attention_bias=False, **kwargs):
             has_relative_attention_bias=has_relative_attention_bias,
             name="SelfAttention",
         )
-        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(
@@ -477,7 +478,7 @@ def __init__(self, config, **kwargs):
             has_relative_attention_bias=False,
             name="EncDecAttention",
         )
-        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
+        self.layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(
@@ -640,7 +641,7 @@ def __init__(self, config, embed_tokens=None, **kwargs):
             TFT5Block(config, has_relative_attention_bias=bool(i == 0), name=f"block_._{i}")
             for i in range(config.num_layers)
         ]
-        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
+        self.final_layer_norm = TFT5LayerNorm(config.d_model, epsilon=config.layer_norm_epsilon, name="final_layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def _prune_heads(self, heads_to_prune):
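
A minimal usage sketch (reviewer illustration, not part of the patch) of the layer after this change, assuming TFT5LayerNorm as defined above and using 512 as a stand-in for config.d_model:

import tensorflow as tf

# The weight shape now comes from the constructor argument rather than from
# input_shape inside build(), so the layer knows its size up front.
layer_norm = TFT5LayerNorm(hidden_size=512, epsilon=1e-6, name="layer_norm")
hidden_states = tf.random.uniform((2, 10, 512))  # (batch, seq_len, d_model)
output = layer_norm(hidden_states)  # first call runs build() and creates the (512,) weight
assert output.shape == hidden_states.shape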