Skip to content

Commit febc635

Browse files
committed
update kqv code
1 parent ca8f698 commit febc635

File tree

1 file changed

+79
-112
lines changed

1 file changed

+79
-112
lines changed

Diff for: llama.cpp

+79-112
Original file line numberDiff line numberDiff line change
@@ -5520,6 +5520,10 @@ struct llm_build_context {
55205520
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
55215521
cb(inpL, "inp_embd", -1);
55225522

5523+
// inp_pos - contains the positions
5524+
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
5525+
cb(inp_pos, "inp_pos", -1);
5526+
55235527
// KQ_scale
55245528
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
55255529
cb(KQ_scale, "KQ_scale", -1);
@@ -5528,10 +5532,6 @@ struct llm_build_context {
55285532
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
55295533
cb(KQ_mask, "KQ_mask", -1);
55305534

5531-
// inp_pos - contains the positions
5532-
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
5533-
cb(inp_pos, "inp_pos", -1);
5534-
55355535
// shift the entire K-cache if needed
55365536
if (do_rope_shift) {
55375537
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
@@ -5544,137 +5544,104 @@ struct llm_build_context {
55445544
cur = llm_build_norm(ctx0, inpL, hparams,
55455545
model.layers[il].attn_norm, NULL,
55465546
LLM_NORM_RMS, cb, il);
5547-
cb(cur, "attention_norm_0", il);
5547+
cb(cur, "attention_norm", il);
55485548

55495549
struct ggml_tensor * attention_norm = cur;
55505550

55515551
// self-attention
55525552
{
55535553
// compute Q and K and RoPE them
5554-
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
5555-
cb(tmpk, "tmpk", il);
5554+
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
5555+
cb(Qcur, "Qcur", il);
55565556

5557-
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
5558-
cb(tmpq, "tmpq", il);
5557+
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
5558+
cb(Kcur, "Kcur", il);
55595559

5560-
struct ggml_tensor * Kcur = ggml_rope_custom(
5561-
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos,
5560+
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
5561+
cb(Vcur, "Vcur", il);
5562+
5563+
Qcur = ggml_rope_custom(
5564+
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
55625565
n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
55635566
ext_factor, attn_factor, beta_fast, beta_slow);
5564-
cb(Kcur, "Kcur", il);
5567+
cb(Qcur, "Qcur", il);
55655568

5566-
struct ggml_tensor * Qcur = ggml_rope_custom(
5567-
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos,
5569+
Kcur = ggml_rope_custom(
5570+
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
55685571
n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
55695572
ext_factor, attn_factor, beta_fast, beta_slow);
5570-
cb(Qcur, "Qcur", il);
5573+
cb(Kcur, "Kcur", il);
55715574

5572-
// store key and value to memory
5573-
{
5574-
// compute the transposed [n_tokens, n_embd] V matrix
5575+
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
55755576

5576-
struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
5577-
cb(tmpv, "tmpv", il);
5577+
auto plamo_llm_build_kqv = [](
5578+
struct ggml_context * ctx,
5579+
const llama_hparams & hparams,
5580+
const llama_kv_cache & kv,
5581+
struct ggml_tensor * wo,
5582+
struct ggml_tensor * q_cur,
5583+
struct ggml_tensor * kq_mask,
5584+
int64_t n_ctx,
5585+
int32_t n_tokens,
5586+
int32_t n_kv,
5587+
const llm_build_cb & cb,
5588+
int il) {
5589+
const int64_t n_embd = hparams.n_embd;
5590+
const int64_t n_head_kv = hparams.n_head_kv;
5591+
const int64_t n_embd_head = hparams.n_embd_head();
5592+
const int64_t n_embd_gqa = hparams.n_embd_gqa();
5593+
5594+
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
5595+
cb(q, "q", il);
5596+
5597+
struct ggml_tensor * k =
5598+
ggml_view_3d(ctx, kv.k_l[il],
5599+
n_embd_head, n_kv, n_head_kv,
5600+
ggml_row_size(kv.k_l[il]->type, n_embd_gqa),
5601+
ggml_row_size(kv.k_l[il]->type, n_embd_head),
5602+
0);
5603+
cb(k, "k", il);
55785604

5579-
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
5580-
cb(Vcur, "Vcur", il);
5605+
// we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att
5606+
struct ggml_tensor * k_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k->ne[0], k->ne[1], q->ne[2]);
5607+
cb(k_repeated, "k_repeated", il);
55815608

5582-
//struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k_l[il])*n_embd_gqa)*(il*n_ctx + kv_head));
5583-
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k_l[il])*n_embd_gqa)*kv_head);
5584-
cb(k, "k", il);
5609+
struct ggml_tensor * kq = ggml_mul_mat(ctx, ggml_repeat(ctx, k, k_repeated), q);
5610+
cb(kq, "kq", il);
5611+
5612+
kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head)));
5613+
cb(kq, "kq_soft_max_ext", il);
55855614

5586-
/*
5587-
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
5588-
( n_ctx)*ggml_element_size(kv_self.v),
5589-
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
5590-
*/
5591-
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_gqa,
5592-
n_ctx*ggml_element_size(kv_self.v_l[il]),
5593-
kv_head*ggml_element_size(kv_self.v_l[il]));
5615+
// split cached v into n_head heads
5616+
struct ggml_tensor * v =
5617+
ggml_view_3d(ctx, kv.v_l[il],
5618+
n_kv, n_embd_head, n_head_kv,
5619+
ggml_element_size(kv.v_l[il])*n_ctx,
5620+
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head,
5621+
0);
55945622
cb(v, "v", il);
55955623

5596-
// important: storing RoPE-ed version of K in the KV cache!
5597-
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
5598-
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
5599-
}
5624+
// we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att
5625+
struct ggml_tensor * v_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, v->ne[0], v->ne[1], q->ne[2]);
5626+
cb(k_repeated, "v_repeated", il);
56005627

5601-
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
5602-
cb(Q, "Q", il);
5628+
struct ggml_tensor * kqv = ggml_mul_mat(ctx, ggml_repeat(ctx, v, v_repeated), kq);
5629+
cb(kqv, "kqv", il);
5630+
5631+
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
5632+
cb(kqv_merged, "kqv_merged", il);
56035633

5604-
/*
5605-
struct ggml_tensor * K =
5606-
ggml_view_3d(ctx0, kv_self.k,
5607-
n_embd_head, n_kv, n_head_kv,
5608-
ggml_element_size(kv_self.k)*n_embd_gqa,
5609-
ggml_element_size(kv_self.k)*n_embd_head,
5610-
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
5611-
*/
5612-
struct ggml_tensor * K =
5613-
ggml_view_3d(ctx0, kv_self.k_l[il],
5614-
n_embd_head, n_kv, n_head_kv,
5615-
ggml_element_size(kv_self.k_l[il])*n_embd_gqa,
5616-
ggml_element_size(kv_self.k_l[il])*n_embd_head,
5617-
0);
5618-
cb(K, "K", il);
5619-
5620-
// K * Q
5621-
//struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
5622-
// we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att
5623-
struct ggml_tensor * K_repeated = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, K->ne[0], K->ne[1], Q->ne[2]);
5624-
cb(K_repeated, "K_repeated", il);
5625-
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, ggml_repeat(ctx0, K, K_repeated), Q);
5626-
cb(KQ, "KQ", il);
5627-
5628-
// KQ_scaled = KQ / sqrt(n_embd_head)
5629-
// KQ_scaled shape [n_kv, n_tokens, n_head, 1]
5630-
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
5631-
cb(KQ_scaled, "KQ_scaled", il);
5632-
5633-
// KQ_masked = mask_past(KQ_scaled)
5634-
struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
5635-
cb(KQ_masked, "KQ_masked", il);
5636-
5637-
// KQ = soft_max(KQ_masked)
5638-
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
5639-
cb(KQ_soft_max, "KQ_soft_max", il);
5640-
5641-
// split cached V into n_head heads
5642-
/*
5643-
struct ggml_tensor * V =
5644-
ggml_view_3d(ctx0, kv_self.v,
5645-
n_kv, n_embd_head, n_head_kv,
5646-
ggml_element_size(kv_self.v)*n_ctx,
5647-
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
5648-
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
5649-
*/
5650-
struct ggml_tensor * V =
5651-
ggml_view_3d(ctx0, kv_self.v_l[il],
5652-
n_kv, n_embd_head, n_head_kv,
5653-
ggml_element_size(kv_self.v_l[il])*n_ctx,
5654-
ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head,
5655-
0);
5656-
cb(V, "V", il);
5657-
5658-
//struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
5659-
// we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att
5660-
struct ggml_tensor * V_repeated = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, V->ne[0], V->ne[1], Q->ne[2]);
5661-
cb(V_repeated, "V_repeated", il);
5662-
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_repeat(ctx0, V, V_repeated), KQ_soft_max);
5663-
cb(KQV, "KQV", il);
5664-
5665-
// KQV_merged = KQV.permute(0, 2, 1, 3)
5666-
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
5667-
cb(KQV_merged, "KQV_merged", il);
5668-
5669-
// cur = KQV_merged.contiguous().view(n_embd, n_tokens)
5670-
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
5671-
cb(cur, "KQV_merged_contiguous", il);
5672-
5673-
// projection (no bias)
5674-
cur = ggml_mul_mat(ctx0,
5634+
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens);
5635+
cb(cur, "kqv_merged_cont", il);
5636+
5637+
cur = ggml_mul_mat(ctx, wo, cur);
5638+
return cur;
5639+
};
5640+
5641+
cur = plamo_llm_build_kqv(ctx0, hparams, kv_self,
56755642
model.layers[il].wo,
5676-
cur);
5677-
cb(cur, "result_wo", il);
5643+
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, cb, il);
5644+
cb(cur, "kqv_out", il);
56785645
}
56795646
struct ggml_tensor * sa_out = cur;
56805647

0 commit comments

Comments
 (0)