Skip to content

Commit 5a2a7ee

Browse files
committed
init push
1 parent 1982a6b commit 5a2a7ee

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

gemma/gemma-inl.h

+17-2
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
231231
const size_t batch_start = interleaved_start / num_queries;
232232
const size_t num_interleaved = num_tokens * num_queries;
233233

234+
// Self extend
235+
constexpr size_t ngb_size = TConfig::self_extend_ngb_size;
236+
constexpr size_t grp_size = TConfig::self_extend_grp_size;
237+
234238
// For the computation of Q, K, and V, it is useful to remember that
235239
// qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kQKVDim, kModelDim]
236240
// and kQStride = kQKVDim * (kIsMHA ? 3 : 1);
@@ -286,12 +290,17 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
286290
const size_t interleaved_idx = task / kKVHeads;
287291
const size_t query_idx = interleaved_idx % num_queries;
288292
const size_t batch_idx = interleaved_idx / num_queries;
289-
const size_t pos = batch_start + batch_idx;
293+
size_t pos = batch_start + batch_idx;
290294
const size_t cache_pos = div_seq_len.Remainder(pos);
291295
const size_t kv_offset = cache_pos * kCachePosSize +
292296
layer * kCacheLayerSize + head * kQKVDim * 2;
293297
KVCache& kv_cache = kv_caches[query_idx];
294298
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
299+
300+
// Self-extend: when embedding the position, use the grouped key position (pos / grp_size) once pos exceeds ngb_size.
301+
if (pos > ngb_size && TConfig::kSelfExtend) {
302+
pos /= grp_size;
303+
}
295304
if constexpr (kIsMHA) {
296305
// For MHA, copy KV into the KV cache from scratch space (see above).
297306
const float* HWY_RESTRICT q =
@@ -321,7 +330,13 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
321330
activations.q.Batch(interleaved_idx) + head * kQStride;
322331

323332
// Apply rope and scaling to Q.
324-
const size_t pos = batch_start + batch_idx;
333+
size_t pos = batch_start + batch_idx;
334+
if (pos > ngb_size && TConfig::kSelfExtend) {  // Self-extend: remap query positions greater than ngb_size.
335+
const size_t grp_pos = pos / grp_size;  // Grouped position (integer division by the group size).
336+
const size_t shift = ngb_size - ngb_size / grp_size;  // Offset added to the grouped position of the query.
337+
const size_t shifted_grouped_pos = grp_pos + shift;
338+
pos = shifted_grouped_pos;  // This remapped position is what PostQK receives for Q.
339+
}
325340
PostQK<TConfig>(q, pos, layer);
326341
MulByConst(kQueryScale, q, kQKVDim);
327342

0 commit comments

Comments
 (0)