@@ -231,6 +231,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
231
231
const size_t batch_start = interleaved_start / num_queries;
232
232
const size_t num_interleaved = num_tokens * num_queries;
233
233
234
+ // Self-extend: window size (ngb_size) and group size (grp_size) used below to remap positions.
235
+ constexpr size_t ngb_size = TConfig::self_extend_ngb_size;
236
+ constexpr size_t grp_size = TConfig::self_extend_grp_size;
237
+
234
238
// For the computation of Q, K, and V, it is useful to remember that
235
239
// qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kQKVDim, kModelDim]
236
240
// and kQStride = kQKVDim * (kIsMHA ? 3 : 1);
@@ -286,12 +290,17 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
286
290
const size_t interleaved_idx = task / kKVHeads ;
287
291
const size_t query_idx = interleaved_idx % num_queries;
288
292
const size_t batch_idx = interleaved_idx / num_queries;
289
- const size_t pos = batch_start + batch_idx;
293
+ size_t pos = batch_start + batch_idx;
290
294
const size_t cache_pos = div_seq_len.Remainder (pos);
291
295
const size_t kv_offset = cache_pos * kCachePosSize +
292
296
layer * kCacheLayerSize + head * kQKVDim * 2 ;
293
297
KVCache& kv_cache = kv_caches[query_idx];
294
298
float * HWY_RESTRICT kv = kv_cache.kv_cache .get () + kv_offset;
299
+
300
+ // With self-extend enabled, embed the key at its grouped position (pos / grp_size) once pos exceeds ngb_size.
301
+ if (pos > ngb_size && TConfig::kSelfExtend ) {
302
+ pos /= grp_size;
303
+ }
295
304
if constexpr (kIsMHA ) {
296
305
// For MHA, copy KV into the KV cache from scratch space (see above).
297
306
const float * HWY_RESTRICT q =
@@ -321,7 +330,13 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
321
330
activations.q .Batch (interleaved_idx) + head * kQStride ;
322
331
323
332
// Apply rope and scaling to Q.
324
- const size_t pos = batch_start + batch_idx;
333
+ size_t pos = batch_start + batch_idx;
334
+ if (pos > ngb_size && TConfig::kSelfExtend ) {
335
+ const grp_pos = pos / grp_size;
336
+ const shift = ngb_size - ngb_size / grp_size
337
+ const shifted_grouped_pos = grp_pos + shift
338
+ pos = shifted_grouped_pos;
339
+ }
325
340
PostQK<TConfig>(q, pos, layer);
326
341
MulByConst (kQueryScale , q, kQKVDim );
327
342
0 commit comments