@@ -212,7 +212,6 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.use_qk_norm = config.use_qk_norm
         self.use_conv2d = False

-        assert not self.use_qk_norm, "QK norm not supported in static attention yet"
         self.wqs = nn.ModuleList(
             [
                 nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
@@ -241,6 +240,13 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         self.rope = _Rope(rope.params.use_hf_rope)

+        if self.use_qk_norm:
+            self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+            self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+        else:
+            self.q_norm = torch.nn.Identity()
+            self.k_norm = torch.nn.Identity()
+
     def forward(
         self,
         x: torch.Tensor,
@@ -275,6 +281,10 @@ def from_conv2ds(ts):
         new_ks = from_conv2ds(new_ks)
         new_vs = from_conv2ds(new_vs)

+        if self.use_qk_norm:
+            new_qs = [self.q_norm(q) for q in new_qs]
+            new_ks = [self.k_norm(k) for k in new_ks]
+
         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
         all_ks = []
@@ -325,6 +335,13 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):

         self.wo.weight.data.copy_(other.wo.weight)

+        if other.use_qk_norm:
+            self.use_qk_norm = True
+            self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps)
+            self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
+            self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
+            self.k_norm.load_state_dict(other.k_norm_fn.state_dict())
+
     def linear_to_conv2d(self):
         def transfer_weight(linear, conv2d):
             conv2d.weight.data.copy_(linear.weight[:, :, None, None])
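For context on the change above: torch.nn.RMSNorm with normalized_shape=head_dim normalizes only over the trailing head_dim axis, so applying the copied q_norm_fn/k_norm_fn weights to each single-head projection in StaticAttention should match applying the same norm to the stacked multi-head tensor (assuming AttentionMHA applies its norms over the head dimension). A minimal standalone sketch of that equivalence; the shapes and variable names below are illustrative, not taken from this commit:

import torch

# Hypothetical shapes for illustration only.
bsz, seq_len, n_heads, head_dim = 2, 8, 4, 64
norm = torch.nn.RMSNorm(head_dim, eps=1e-5)

q = torch.randn(bsz, seq_len, n_heads, head_dim)
# Per-head application, as in StaticAttention's list of single-head projections.
per_head = [norm(q[:, :, h, :]) for h in range(n_heads)]
# Whole-tensor application over the trailing head_dim axis (assumed AttentionMHA behavior).
stacked = norm(q)
assert all(torch.allclose(per_head[h], stacked[:, :, h, :]) for h in range(n_heads))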