Commit 7e487c2

Support QK norm in static attention
Differential Revision: D72401511
Pull Request resolved: #9879
1 parent: e9c2315 · commit: 7e487c2

2 files changed (+24 -4 lines)

Diff for: examples/models/llama/static_attention.py (+18 -1)

@@ -212,7 +212,6 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.use_qk_norm = config.use_qk_norm
         self.use_conv2d = False

-        assert not self.use_qk_norm, "QK norm not supported in static attention yet"
         self.wqs = nn.ModuleList(
             [
                 nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
@@ -241,6 +240,13 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         self.rope = _Rope(rope.params.use_hf_rope)

+        if self.use_qk_norm:
+            self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+            self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+        else:
+            self.q_norm = torch.nn.Identity()
+            self.k_norm = torch.nn.Identity()
+
     def forward(
         self,
         x: torch.Tensor,
@@ -275,6 +281,10 @@ def from_conv2ds(ts):
         new_ks = from_conv2ds(new_ks)
         new_vs = from_conv2ds(new_vs)

+        if self.use_qk_norm:
+            new_qs = [self.q_norm(q) for q in new_qs]
+            new_ks = [self.k_norm(k) for k in new_ks]
+
         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
         all_ks = []
@@ -325,6 +335,13 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):

         self.wo.weight.data.copy_(other.wo.weight)

+        if other.use_qk_norm:
+            self.use_qk_norm = True
+            self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps)
+            self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
+            self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
+            self.k_norm.load_state_dict(other.k_norm_fn.state_dict())
+
     def linear_to_conv2d(self):
         def transfer_weight(linear, conv2d):
             conv2d.weight.data.copy_(linear.weight[:, :, None, None])
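The __init__ and forward hunks share one pattern: build a per-head RMSNorm for queries and keys (or torch.nn.Identity when QK norm is off) and apply it to every head before RoPE. The standalone sketch below only mirrors that flow under illustrative assumptions (head_dim, eps, head counts, and the random inputs are placeholders, not the module's actual wiring; torch.nn.RMSNorm requires PyTorch 2.4+):

    import torch

    # Placeholder values; the real module reads these from ModelArgs.
    head_dim, eps, use_qk_norm = 64, 1e-5, True

    # Same construction as the __init__ hunk: RMSNorm over head_dim, or a no-op.
    q_norm = torch.nn.RMSNorm(head_dim, eps) if use_qk_norm else torch.nn.Identity()
    k_norm = torch.nn.RMSNorm(head_dim, eps) if use_qk_norm else torch.nn.Identity()

    # Stand-ins for the per-head projections built in forward() (4 Q heads, 2 KV heads).
    new_qs = [torch.randn(1, 8, head_dim) for _ in range(4)]
    new_ks = [torch.randn(1, 8, head_dim) for _ in range(2)]

    # Normalize before the rotary embedding is applied, matching the order in the diff.
    new_qs = [q_norm(q) for q in new_qs]
    new_ks = [k_norm(k) for k in new_ks]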

Diff for: examples/models/llama/tests/test_static_attention.py (+6 -3)

@@ -17,12 +17,13 @@ def setUp(self):
         torch.manual_seed(42)

     def test_without_cache(self):
-        def test(use_conv2d):
+        def test(use_qk_norm, use_conv2d):
             config = ModelArgs(
                 dim=64,
                 n_heads=4,
                 n_kv_heads=2,
                 max_seq_len=8,
+                use_qk_norm=use_qk_norm,
             )
             layer_id = 0
             rope = Rope(config)
@@ -47,8 +48,10 @@ def test(use_conv2d):
             )
             self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())

-        test(True)
-        test(False)
+        test(True, True)
+        test(True, False)
+        test(False, True)
+        test(False, False)

     def test_hf_rope_without_cache(self):
         config = ModelArgs(
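The four explicit calls at the end enumerate every (use_qk_norm, use_conv2d) combination. As a sketch only, not part of the commit, the same sweep could be written inside the test body with itertools.product:

    from itertools import product

    # Runs test() for (True, True), (True, False), (False, True), (False, False),
    # matching the four explicit calls in the diff.
    for use_qk_norm, use_conv2d in product([True, False], repeat=2):
        test(use_qk_norm, use_conv2d)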
