 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.cuda.amp import autocast
 from torch.nn import CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
@@ -124,7 +125,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):


 class GPT2Attention(nn.Module):
-    def __init__(self, config, is_cross_attention=False):
+    def __init__(self, config, is_cross_attention=False, layer_idx=None):
         super().__init__()

         max_positions = config.max_position_embeddings
@@ -148,6 +149,11 @@ def __init__(self, config, is_cross_attention=False):
         self.scale_attn_weights = config.scale_attn_weights
         self.is_cross_attention = is_cross_attention

+        # [Required for Mistral-GPT2] Layer-wise attention scaling, reordering, and upcasting
+        self.scale_attn_by_layer = config.scale_attn_by_layer
+        self.layer_idx = layer_idx
+        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
+
         if self.is_cross_attention:
             self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
             self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
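The three new attributes assume matching fields on the fork's `GPT2Config` (`scale_attn_by_layer`, `reorder_and_upcast_attn`); `layer_idx` is the block's 1-indexed position in the stack. As a rough sketch (not part of the diff) of the combined multiplier the layer-wise flag implies when `scale_attn_weights` is also enabled, with illustrative GPT-2-small-like numbers:

```python
# Illustrative only: combined scale applied to QK^T per layer, assuming both
# scale_attn_weights and scale_attn_by_layer are on and layer_idx is 1-indexed
# (as wired further down via `layer_idx=i + 1`).
d_k, n_layer = 64, 12
for layer_idx in range(1, n_layer + 1):
    scale_factor = 1.0 / (d_k ** 0.5) / layer_idx
    print(f"layer {layer_idx:2d}: alpha = {scale_factor:.5f}")
# layer  1: alpha = 0.12500  ...  layer 12: alpha = 0.01042
```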
@@ -176,10 +182,49 @@ def prune_heads(self, heads):
         self.pruned_heads = self.pruned_heads.union(heads)

     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))

-        if self.scale_attn_weights:
-            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
+        if self.reorder_and_upcast_attn:
+            # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
+            bsz, num_heads, seq_len, dk = query.size()
+
+            # Preallocate attn_weights for `baddbmm`
+            attn_weights = torch.empty(
+                bsz * num_heads,
+                seq_len,
+                seq_len,
+                dtype=torch.float32,
+                device=query.device
+            )
+
+            # Compute Scale Factor
+            scale_factor = 1.0
+            if self.scale_attn_weights:
+                scale_factor /= float(value.size(-1)) ** 0.5
+
+            if self.scale_attn_by_layer:
+                scale_factor /= float(self.layer_idx)
+
+            # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
+            with autocast(enabled=False):
+                q, k = query.reshape(-1, seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, seq_len)
+                attn_weights = torch.baddbmm(
+                    attn_weights,
+                    q.float(),
+                    k.float(),
+                    beta=0,
+                    alpha=scale_factor
+                )
+                attn_weights = attn_weights.reshape(bsz, num_heads, seq_len, seq_len)
+
+        else:
+            attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+            if self.scale_attn_weights:
+                attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
+
+            # [Required for Mistral-GPT2] Layer-wise attention scaling
+            if self.scale_attn_by_layer:
+                attn_weights = attn_weights / float(self.layer_idx)

         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
@@ -192,6 +237,9 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             attn_weights = attn_weights + attention_mask

         attn_weights = nn.Softmax(dim=-1)(attn_weights)
+
+        # Downcast (if necessary) back to V dtype (half/fp16 if mixed-precision) -- No-Op if in float()
+        attn_weights = attn_weights.type(value.dtype)

         attn_weights = self.attn_dropout(attn_weights)

         # Mask heads if we want to
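For reference, a standalone sanity check (not part of the diff) that the reordered `baddbmm` branch with `beta=0` and `alpha=scale_factor` matches the plain matmul-then-divide branch; it runs on CPU and skips `autocast`, which only takes effect on CUDA. Shapes and the layer index are made up for illustration.

```python
import torch

bsz, num_heads, seq_len, dk = 2, 4, 8, 16
layer_idx = 3
query = torch.randn(bsz, num_heads, seq_len, dk)
key = torch.randn(bsz, num_heads, seq_len, dk)

# Standard path: matmul, then divide by sqrt(dk) and by the layer index
ref = torch.matmul(query, key.transpose(-1, -2)) / (dk ** 0.5) / layer_idx

# Reordered path: preallocated fp32 buffer + baddbmm folding all scaling into alpha
scale_factor = 1.0 / (dk ** 0.5) / layer_idx
buf = torch.empty(bsz * num_heads, seq_len, seq_len, dtype=torch.float32)
q = query.reshape(-1, seq_len, dk)
k = key.transpose(-1, -2).reshape(-1, dk, seq_len)
out = torch.baddbmm(buf, q.float(), k.float(), beta=0, alpha=scale_factor)
out = out.reshape(bsz, num_heads, seq_len, seq_len)

print(torch.allclose(ref, out, atol=1e-5))  # True
```

Under mixed precision the `out` tensor stays in float32 through the softmax; the `.type(value.dtype)` line above then downcasts it back to the value tensor's dtype before dropout and the final matmul with V.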
@@ -287,13 +335,13 @@ def forward(self, hidden_states):


 class GPT2Block(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
         hidden_size = config.hidden_size
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPT2Attention(config)
+        self.attn = GPT2Attention(config, layer_idx=layer_idx)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

         if config.add_cross_attention:
@@ -581,7 +629,7 @@ def __init__(self, config):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

         self.drop = nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)])
+        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i + 1) for i in range(config.num_hidden_layers)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

         self.init_weights()
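A small illustrative note (not from the diff) on the `layer_idx=i + 1` wiring: indices are 1-based, so the `float(self.layer_idx)` divisor used in `_attn` starts at 1 for the first block rather than 0.

```python
# Illustrative only: 1-indexed wiring keeps the per-layer divisor nonzero.
num_hidden_layers = 12  # e.g. GPT-2 small
layer_indices = [i + 1 for i in range(num_hidden_layers)]
assert layer_indices[0] == 1 and layer_indices[-1] == num_hidden_layers
```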