
Commit 5984576

Merge pull request #13 from stanford-crfm/openai-initialization
OpenAI GPT-2 Initialization
2 parents 3a71d8d + 9fd657c commit 5984576

File tree

1 file changed: +15 -15 lines changed


src/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 15 additions & 15 deletions
@@ -15,6 +15,7 @@
 # limitations under the License.
 """PyTorch OpenAI GPT-2 model."""
 
+import math
 import os
 from dataclasses import dataclass
 from typing import Optional, Tuple
@@ -188,13 +189,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             bsz, num_heads, seq_len, dk = query.size()
 
             # Preallocate attn_weights for `baddbmm`
-            attn_weights = torch.empty(
-                bsz * num_heads,
-                seq_len,
-                seq_len,
-                dtype=torch.float32,
-                device=query.device
-            )
+            attn_weights = torch.empty(bsz * num_heads, seq_len, seq_len, dtype=torch.float32, device=query.device)
 
             # Compute Scale Factor
             scale_factor = 1.0
@@ -207,13 +202,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
             with autocast(enabled=False):
                 q, k = query.reshape(-1, seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, seq_len)
-                attn_weights = torch.baddbmm(
-                    attn_weights,
-                    q.float(),
-                    k.float(),
-                    beta=0,
-                    alpha=scale_factor
-                )
+                attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
                 attn_weights = attn_weights.reshape(bsz, num_heads, seq_len, seq_len)
 
         else:
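For readers who want to see the compacted `baddbmm` path in isolation, the snippet below is a minimal, self-contained sketch of the same idea: preallocate an fp32 buffer, fold the 1/√dk scaling into `baddbmm`'s `alpha`, and reshape back to per-head scores. The function name and toy shapes are illustrative only (not part of the commit), and the `autocast(enabled=False)` guard is omitted since it only matters under mixed precision.

import math

import torch


def upcast_reordered_attn_scores(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    """query/key: [bsz, num_heads, seq_len, dk] -> raw scores [bsz, num_heads, seq_len, seq_len]."""
    bsz, num_heads, seq_len, dk = query.size()

    # Preallocate the fp32 output buffer for `baddbmm`; with beta=0 its contents are never read.
    attn_weights = torch.empty(bsz * num_heads, seq_len, seq_len, dtype=torch.float32, device=query.device)

    # Fold the 1 / sqrt(dk) scaling into baddbmm's `alpha` instead of a separate multiply.
    scale_factor = 1.0 / math.sqrt(dk)

    # Flatten the batch/head dims so one batched matmul covers every head at once.
    q = query.reshape(-1, seq_len, dk)
    k = key.transpose(-1, -2).reshape(-1, dk, seq_len)
    attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
    return attn_weights.reshape(bsz, num_heads, seq_len, seq_len)


if __name__ == "__main__":
    q = torch.randn(2, 4, 8, 16)
    k = torch.randn(2, 4, 8, 16)
    print(upcast_reordered_attn_scores(q, k).shape)  # torch.Size([2, 4, 8, 8])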
@@ -442,6 +431,17 @@ def _init_weights(self, module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
+        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
+        #
+        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+        for name, p in module.named_parameters():
+            if "c_proj" in name and "weight" in name:
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
+
 
 @dataclass
 class GPT2DoubleHeadsModelOutput(ModelOutput):
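As a quick sanity check on the numbers this scheme produces: the factor 2 * n_layer reflects the two residual sublayers (attention and MLP) in each block, so N = 2 * n_layer residual layers in total. The small, hypothetical calculation below uses the stock GPT2Config default initializer_range of 0.02 and the depths of the released GPT-2 small, medium, and XL checkpoints; the printed values are derived from the formula in the patch, not quoted from the commit.

import math

initializer_range = 0.02  # GPT2Config default
for n_layer in (12, 24, 48):  # GPT-2 small, medium, and XL depths
    std = initializer_range / math.sqrt(2 * n_layer)
    print(f"n_layer={n_layer:2d} -> c_proj weight init std = {std:.5f}")
# n_layer=12 -> 0.00408, n_layer=24 -> 0.00289, n_layer=48 -> 0.00204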
@@ -629,7 +629,7 @@ def __init__(self, config):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
 
         self.drop = nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i+1) for i in range(config.num_hidden_layers)])
+        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i + 1) for i in range(config.num_hidden_layers)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         self.init_weights()
