
Commit 3a71d8d

Merge pull request #12 from stanford-crfm/upcast-scaling
Add Layer Scaling & Upcast/Reordering Flags + Functionality
2 parents: 7bd16b8 + 53d145a

File tree: 2 files changed (+65 -7 lines)

src/transformers/models/gpt2/configuration_gpt2.py  (+10)

@@ -113,6 +113,12 @@ class GPT2Config(PretrainedConfig):
             Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
+        scale_attn_by_layer (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            [Mistral-GPT2] Whether to additionally scale attention weights by 1 / layer_idx.
+        reorder_attn (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            [Mistral-GPT2] Whether to scale keys (K) prior to computing attention (dot-product).
+        upscale_attn (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            [Mistral-GPT2] Whether to upcast the attention dot-product/softmax to float() when training with mixed precision.
 
     Example::
 
@@ -162,6 +168,8 @@ def __init__(
         use_cache=True,
         bos_token_id=50256,
         eos_token_id=50256,
+        scale_attn_by_layer=False,
+        reorder_and_upcast_attn=False,
         **kwargs
     ):
         self.vocab_size = vocab_size
@@ -185,6 +193,8 @@ def __init__(
         self.gradient_checkpointing = gradient_checkpointing
         self.scale_attn_weights = scale_attn_weights
         self.use_cache = use_cache
+        self.scale_attn_by_layer = scale_attn_by_layer
+        self.reorder_and_upcast_attn = reorder_and_upcast_attn
 
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
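
For reference, a minimal usage sketch (not part of this commit) of how the two flags introduced above could be switched on when building a model from this branch. The argument names come directly from the diff; the `GPT2LMHeadModel` class, the import path, and the other config values are assumed standard and chosen only for illustration.

    # Hypothetical usage sketch -- not part of this commit.
    # Enables the Mistral-GPT2 flags added in configuration_gpt2.py above.
    from transformers import GPT2Config, GPT2LMHeadModel

    config = GPT2Config(
        n_layer=12,
        n_head=12,
        n_embd=768,
        scale_attn_by_layer=True,       # extra 1 / layer_idx scaling of attention weights
        reorder_and_upcast_attn=True,   # fold scaling into baddbmm and upcast the softmax inputs to fp32
    )
    model = GPT2LMHeadModel(config)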

src/transformers/models/gpt2/modeling_gpt2.py  (+55 -7)

@@ -22,6 +22,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.cuda.amp import autocast
 from torch.nn import CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
@@ -124,7 +125,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
 
 
 class GPT2Attention(nn.Module):
-    def __init__(self, config, is_cross_attention=False):
+    def __init__(self, config, is_cross_attention=False, layer_idx=None):
         super().__init__()
 
         max_positions = config.max_position_embeddings
@@ -148,6 +149,11 @@ def __init__(self, config, is_cross_attention=False):
         self.scale_attn_weights = config.scale_attn_weights
         self.is_cross_attention = is_cross_attention
 
+        # [Required for Mistral-GPT2] Layer-wise attention scaling, reordering, and upcasting
+        self.scale_attn_by_layer = config.scale_attn_by_layer
+        self.layer_idx = layer_idx
+        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
+
         if self.is_cross_attention:
             self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
             self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
@@ -176,10 +182,49 @@ def prune_heads(self, heads):
         self.pruned_heads = self.pruned_heads.union(heads)
 
     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
+        if self.reorder_and_upcast_attn:
+            # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
+            bsz, num_heads, seq_len, dk = query.size()
+
+            # Preallocate attn_weights for `baddbmm`
+            attn_weights = torch.empty(
+                bsz * num_heads,
+                seq_len,
+                seq_len,
+                dtype=torch.float32,
+                device=query.device
+            )
+
+            # Compute Scale Factor
+            scale_factor = 1.0
+            if self.scale_attn_weights:
+                scale_factor /= float(value.size(-1)) ** 0.5
+
+            if self.scale_attn_by_layer:
+                scale_factor /= float(self.layer_idx)
+
+            # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
+            with autocast(enabled=False):
+                q, k = query.reshape(-1, seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, seq_len)
+                attn_weights = torch.baddbmm(
+                    attn_weights,
+                    q.float(),
+                    k.float(),
+                    beta=0,
+                    alpha=scale_factor
+                )
+                attn_weights = attn_weights.reshape(bsz, num_heads, seq_len, seq_len)
+
+        else:
+            attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+            if self.scale_attn_weights:
+                attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
+
+            # [Required for Mistral-GPT2] Layer-wise attention scaling
+            if self.scale_attn_by_layer:
+                attn_weights = attn_weights / float(self.layer_idx)
 
         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
@@ -192,6 +237,9 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             attn_weights = attn_weights + attention_mask
 
         attn_weights = nn.Softmax(dim=-1)(attn_weights)
+
+        # Downcast (if necessary) back to V dtype (half/fp16 if mixed-precision) -- No-Op if in float()
+        attn_weights = attn_weights.type(value.dtype)
         attn_weights = self.attn_dropout(attn_weights)
 
         # Mask heads if we want to
@@ -287,13 +335,13 @@ def forward(self, hidden_states):
 
 
 class GPT2Block(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
         hidden_size = config.hidden_size
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPT2Attention(config)
+        self.attn = GPT2Attention(config, layer_idx=layer_idx)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
         if config.add_cross_attention:
@@ -581,7 +629,7 @@ def __init__(self, config):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
 
         self.drop = nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)])
+        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i+1) for i in range(config.num_hidden_layers)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         self.init_weights()
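
As a sanity check on the reordered `_attn` path above, the sketch below (illustration only, not part of the commit; shapes and the `layer_idx` value are chosen arbitrarily) verifies that folding 1 / (sqrt(d_k) * layer_idx) into `torch.baddbmm`'s `alpha` with `beta=0` reproduces the plain matmul-then-divide computation. The fp32 upcast under `autocast(enabled=False)` only matters when training with mixed precision and is omitted here, since everything already runs in full precision.

    # Illustrative equivalence check -- not part of this commit.
    import torch

    bsz, num_heads, seq_len, dk = 2, 4, 8, 16
    layer_idx = 3  # layers are numbered from 1 in the GPT2Model change above

    query = torch.randn(bsz, num_heads, seq_len, dk)
    key = torch.randn(bsz, num_heads, seq_len, dk)

    # Plain path: matmul, then divide by sqrt(d_k) and by layer_idx
    plain = torch.matmul(query, key.transpose(-1, -2))
    plain = plain / (float(dk) ** 0.5) / float(layer_idx)

    # Reordered path: fold both scale factors into baddbmm's alpha
    # (beta=0 means the preallocated buffer's contents are ignored)
    scale_factor = 1.0 / (float(dk) ** 0.5) / float(layer_idx)
    attn_weights = torch.empty(bsz * num_heads, seq_len, seq_len, dtype=torch.float32)
    q = query.reshape(-1, seq_len, dk)
    k = key.transpose(-1, -2).reshape(-1, dk, seq_len)
    reordered = torch.baddbmm(attn_weights, q, k, beta=0, alpha=scale_factor)
    reordered = reordered.reshape(bsz, num_heads, seq_len, seq_len)

    print(torch.allclose(plain, reordered, atol=1e-5))  # expected: True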
