# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""

+ import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -188,13 +189,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        bsz, num_heads, seq_len, dk = query.size()

        # Preallocate attn_weights for `baddbmm`
-         attn_weights = torch.empty(
-             bsz * num_heads,
-             seq_len,
-             seq_len,
-             dtype=torch.float32,
-             device=query.device
-         )
+         attn_weights = torch.empty(bsz * num_heads, seq_len, seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
@@ -207,13 +202,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
            # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
            with autocast(enabled=False):
                q, k = query.reshape(-1, seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, seq_len)
-                 attn_weights = torch.baddbmm(
-                     attn_weights,
-                     q.float(),
-                     k.float(),
-                     beta=0,
-                     alpha=scale_factor
-                 )
+                 attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)

                attn_weights = attn_weights.reshape(bsz, num_heads, seq_len, seq_len)

        else:
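For reference, the hunk above folds the 1/√dk scaling into `baddbmm`'s `alpha` argument and computes the attention scores in float32 even when the surrounding model runs under mixed precision. Below is a minimal standalone sketch of that pattern; the function name and shapes are illustrative only, and the `autocast(enabled=False)` guard from the diff is omitted because no autocast region is active here.

import math
import torch

def upcast_reordered_scores(query, key, scale_factor):
    # query, key: (bsz, num_heads, seq_len, dk)
    bsz, num_heads, seq_len, dk = query.size()

    # Preallocate a float32 output buffer for baddbmm (beta=0 means its contents are ignored)
    attn_weights = torch.empty(bsz * num_heads, seq_len, seq_len, dtype=torch.float32, device=query.device)

    # Flatten the batch/head dims and compute Q @ K^T in float32, applying the scale via alpha
    q = query.reshape(-1, seq_len, dk)
    k = key.transpose(-1, -2).reshape(-1, dk, seq_len)
    attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
    return attn_weights.reshape(bsz, num_heads, seq_len, seq_len)

# Usage: half-precision inputs, scores come back in float32
q = torch.randn(2, 12, 16, 64, dtype=torch.float16)
k = torch.randn(2, 12, 16, 64, dtype=torch.float16)
scores = upcast_reordered_scores(q, k, scale_factor=1.0 / math.sqrt(64))
assert scores.dtype == torch.float32 and scores.shape == (2, 12, 16, 16)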
@@ -442,6 +431,17 @@ def _init_weights(self, module):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

+         # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+         #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+         #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+         #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
+         #
+         # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+         for name, p in module.named_parameters():
+             if "c_proj" in name and "weight" in name:
+                 # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                 p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
+

@dataclass
class GPT2DoubleHeadsModelOutput(ModelOutput):
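A short numeric illustration of the scheme added above: each transformer block contributes two residual projections (the attention and MLP `c_proj` layers), so N = 2 * n_layer and the usual init std is divided by √(2·n_layer). The config values below are illustrative defaults, not read from a real GPT2Config.

import math
import torch
import torch.nn as nn

initializer_range = 0.02   # GPT-2's default init std (assumed here)
n_layer = 12               # number of transformer blocks (assumed here)

# N = 2 * n_layer residual layers -> scaled std = initializer_range / sqrt(2 * n_layer)
scaled_std = initializer_range / math.sqrt(2 * n_layer)

c_proj_weight = nn.Parameter(torch.empty(768, 768))
c_proj_weight.data.normal_(mean=0.0, std=scaled_std)

print(f"standard std: {initializer_range}, scaled c_proj std: {scaled_std:.5f}")
# For n_layer=12: 0.02 / sqrt(24) ≈ 0.00408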
@@ -629,7 +629,7 @@ def __init__(self, config):
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
-         self.h = nn.ModuleList([GPT2Block(config, layer_idx=i + 1) for i in range(config.num_hidden_layers)])
+         self.h = nn.ModuleList([GPT2Block(config, layer_idx=i + 1) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.init_weights()