@@ -5520,6 +5520,10 @@ struct llm_build_context {
5520
5520
inpL = llm_build_inp_embd (ctx0, hparams, batch, model.tok_embd , cb);
5521
5521
cb (inpL, " inp_embd" , -1 );
5522
5522
5523
+ // inp_pos - contains the positions
5524
+ struct ggml_tensor * inp_pos = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_tokens);
5525
+ cb (inp_pos, " inp_pos" , -1 );
5526
+
5523
5527
// KQ_scale
5524
5528
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d (ctx0, GGML_TYPE_F32, 1 );
5525
5529
cb (KQ_scale, " KQ_scale" , -1 );
@@ -5528,10 +5532,6 @@ struct llm_build_context {
5528
5532
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d (ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1 );
5529
5533
cb (KQ_mask, " KQ_mask" , -1 );
5530
5534
5531
- // inp_pos - contains the positions
5532
- struct ggml_tensor * inp_pos = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_tokens);
5533
- cb (inp_pos, " inp_pos" , -1 );
5534
-
5535
5535
// shift the entire K-cache if needed
5536
5536
if (do_rope_shift) {
5537
5537
llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
@@ -5544,137 +5544,104 @@ struct llm_build_context {
5544
5544
cur = llm_build_norm (ctx0, inpL, hparams,
5545
5545
model.layers [il].attn_norm , NULL ,
5546
5546
LLM_NORM_RMS, cb, il);
5547
- cb (cur, " attention_norm_0 " , il);
5547
+ cb (cur, " attention_norm " , il);
5548
5548
5549
5549
struct ggml_tensor * attention_norm = cur;
5550
5550
5551
5551
// self-attention
5552
5552
{
5553
5553
// compute Q and K and RoPE them
5554
- struct ggml_tensor * tmpk = ggml_mul_mat (ctx0, model.layers [il].wk , cur);
5555
- cb (tmpk , " tmpk " , il);
5554
+ struct ggml_tensor * Qcur = ggml_mul_mat (ctx0, model.layers [il].wq , cur);
5555
+ cb (Qcur , " Qcur " , il);
5556
5556
5557
- struct ggml_tensor * tmpq = ggml_mul_mat (ctx0, model.layers [il].wq , cur);
5558
- cb (tmpq , " tmpq " , il);
5557
+ struct ggml_tensor * Kcur = ggml_mul_mat (ctx0, model.layers [il].wk , cur);
5558
+ cb (Kcur , " Kcur " , il);
5559
5559
5560
- struct ggml_tensor * Kcur = ggml_rope_custom (
5561
- ctx0, ggml_reshape_3d (ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos,
5560
+ struct ggml_tensor * Vcur = ggml_mul_mat (ctx0, model.layers [il].wv , cur);
5561
+ cb (Vcur, " Vcur" , il);
5562
+
5563
+ Qcur = ggml_rope_custom (
5564
+ ctx0, ggml_reshape_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
5562
5565
n_embd_head, 2 , 0 , n_orig_ctx, freq_base, freq_scale,
5563
5566
ext_factor, attn_factor, beta_fast, beta_slow);
5564
- cb (Kcur , " Kcur " , il);
5567
+ cb (Qcur , " Qcur " , il);
5565
5568
5566
- struct ggml_tensor * Qcur = ggml_rope_custom (
5567
- ctx0, ggml_reshape_3d (ctx0, tmpq , n_embd_head, n_head, n_tokens), inp_pos,
5569
+ Kcur = ggml_rope_custom (
5570
+ ctx0, ggml_reshape_3d (ctx0, Kcur , n_embd_head, n_head_kv, n_tokens), inp_pos,
5568
5571
n_embd_head, 2 , 0 , n_orig_ctx, freq_base, freq_scale,
5569
5572
ext_factor, attn_factor, beta_fast, beta_slow);
5570
- cb (Qcur , " Qcur " , il);
5573
+ cb (Kcur , " Kcur " , il);
5571
5574
5572
- // store key and value to memory
5573
- {
5574
- // compute the transposed [n_tokens, n_embd] V matrix
5575
+ llm_build_kv_store (ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
5575
5576
5576
- struct ggml_tensor * tmpv = ggml_mul_mat (ctx0, model.layers [il].wv , cur);
5577
- cb (tmpv, " tmpv" , il);
5577
+ auto plamo_llm_build_kqv = [](
5578
+ struct ggml_context * ctx,
5579
+ const llama_hparams & hparams,
5580
+ const llama_kv_cache & kv,
5581
+ struct ggml_tensor * wo,
5582
+ struct ggml_tensor * q_cur,
5583
+ struct ggml_tensor * kq_mask,
5584
+ int64_t n_ctx,
5585
+ int32_t n_tokens,
5586
+ int32_t n_kv,
5587
+ const llm_build_cb & cb,
5588
+ int il) {
5589
+ const int64_t n_embd = hparams.n_embd ;
5590
+ const int64_t n_head_kv = hparams.n_head_kv ;
5591
+ const int64_t n_embd_head = hparams.n_embd_head ();
5592
+ const int64_t n_embd_gqa = hparams.n_embd_gqa ();
5593
+
5594
+ struct ggml_tensor * q = ggml_permute (ctx, q_cur, 0 , 2 , 1 , 3 );
5595
+ cb (q, " q" , il);
5596
+
5597
+ struct ggml_tensor * k =
5598
+ ggml_view_3d (ctx, kv.k_l [il],
5599
+ n_embd_head, n_kv, n_head_kv,
5600
+ ggml_row_size (kv.k_l [il]->type , n_embd_gqa),
5601
+ ggml_row_size (kv.k_l [il]->type , n_embd_head),
5602
+ 0 );
5603
+ cb (k, " k" , il);
5578
5604
5579
- struct ggml_tensor * Vcur = ggml_transpose (ctx0, ggml_reshape_2d (ctx0, tmpv, n_embd_gqa, n_tokens));
5580
- cb (Vcur, " Vcur" , il);
5605
+ // we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att
5606
+ struct ggml_tensor * k_repeated = ggml_new_tensor_3d (ctx, GGML_TYPE_F32, k->ne [0 ], k->ne [1 ], q->ne [2 ]);
5607
+ cb (k_repeated, " k_repeated" , il);
5581
5608
5582
- // struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k_l[il])*n_embd_gqa)*(il*n_ctx + kv_head));
5583
- struct ggml_tensor * k = ggml_view_1d (ctx0, kv_self.k_l [il], n_tokens*n_embd_gqa, (ggml_element_size (kv_self.k_l [il])*n_embd_gqa)*kv_head);
5584
- cb (k, " k" , il);
5609
+ struct ggml_tensor * kq = ggml_mul_mat (ctx, ggml_repeat (ctx, k, k_repeated), q);
5610
+ cb (kq, " kq" , il);
5611
+
5612
+ kq = ggml_soft_max_ext (ctx, kq, kq_mask, 1 .0f /sqrtf (float (n_embd_head)));
5613
+ cb (kq, " kq_soft_max_ext" , il);
5585
5614
5586
- /*
5587
- struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
5588
- ( n_ctx)*ggml_element_size(kv_self.v),
5589
- (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
5590
- */
5591
- struct ggml_tensor * v = ggml_view_2d (ctx0, kv_self.v_l [il], n_tokens, n_embd_gqa,
5592
- n_ctx*ggml_element_size (kv_self.v_l [il]),
5593
- kv_head*ggml_element_size (kv_self.v_l [il]));
5615
+ // split cached v into n_head heads
5616
+ struct ggml_tensor * v =
5617
+ ggml_view_3d (ctx, kv.v_l [il],
5618
+ n_kv, n_embd_head, n_head_kv,
5619
+ ggml_element_size (kv.v_l [il])*n_ctx,
5620
+ ggml_element_size (kv.v_l [il])*n_ctx*n_embd_head,
5621
+ 0 );
5594
5622
cb (v, " v" , il);
5595
5623
5596
- // important: storing RoPE-ed version of K in the KV cache!
5597
- ggml_build_forward_expand (gf, ggml_cpy (ctx0, Kcur, k));
5598
- ggml_build_forward_expand (gf, ggml_cpy (ctx0, Vcur, v));
5599
- }
5624
+ // we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att
5625
+ struct ggml_tensor * v_repeated = ggml_new_tensor_3d (ctx, GGML_TYPE_F32, v->ne [0 ], v->ne [1 ], q->ne [2 ]);
5626
+ cb (k_repeated, " v_repeated" , il);
5600
5627
5601
- struct ggml_tensor * Q = ggml_permute (ctx0, Qcur, 0 , 2 , 1 , 3 );
5602
- cb (Q, " Q" , il);
5628
+ struct ggml_tensor * kqv = ggml_mul_mat (ctx, ggml_repeat (ctx, v, v_repeated), kq);
5629
+ cb (kqv, " kqv" , il);
5630
+
5631
+ struct ggml_tensor * kqv_merged = ggml_permute (ctx, kqv, 0 , 2 , 1 , 3 );
5632
+ cb (kqv_merged, " kqv_merged" , il);
5603
5633
5604
- /*
5605
- struct ggml_tensor * K =
5606
- ggml_view_3d(ctx0, kv_self.k,
5607
- n_embd_head, n_kv, n_head_kv,
5608
- ggml_element_size(kv_self.k)*n_embd_gqa,
5609
- ggml_element_size(kv_self.k)*n_embd_head,
5610
- ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
5611
- */
5612
- struct ggml_tensor * K =
5613
- ggml_view_3d (ctx0, kv_self.k_l [il],
5614
- n_embd_head, n_kv, n_head_kv,
5615
- ggml_element_size (kv_self.k_l [il])*n_embd_gqa,
5616
- ggml_element_size (kv_self.k_l [il])*n_embd_head,
5617
- 0 );
5618
- cb (K, " K" , il);
5619
-
5620
- // K * Q
5621
- // struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
5622
- // we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att
5623
- struct ggml_tensor * K_repeated = ggml_new_tensor_3d (ctx0, GGML_TYPE_F32, K->ne [0 ], K->ne [1 ], Q->ne [2 ]);
5624
- cb (K_repeated, " K_repeated" , il);
5625
- struct ggml_tensor * KQ = ggml_mul_mat (ctx0, ggml_repeat (ctx0, K, K_repeated), Q);
5626
- cb (KQ, " KQ" , il);
5627
-
5628
- // KQ_scaled = KQ / sqrt(n_embd_head)
5629
- // KQ_scaled shape [n_kv, n_tokens, n_head, 1]
5630
- struct ggml_tensor * KQ_scaled = ggml_scale (ctx0, KQ, KQ_scale);
5631
- cb (KQ_scaled, " KQ_scaled" , il);
5632
-
5633
- // KQ_masked = mask_past(KQ_scaled)
5634
- struct ggml_tensor * KQ_masked = ggml_add (ctx0, KQ_scaled, KQ_mask);
5635
- cb (KQ_masked, " KQ_masked" , il);
5636
-
5637
- // KQ = soft_max(KQ_masked)
5638
- struct ggml_tensor * KQ_soft_max = ggml_soft_max (ctx0, KQ_masked);
5639
- cb (KQ_soft_max, " KQ_soft_max" , il);
5640
-
5641
- // split cached V into n_head heads
5642
- /*
5643
- struct ggml_tensor * V =
5644
- ggml_view_3d(ctx0, kv_self.v,
5645
- n_kv, n_embd_head, n_head_kv,
5646
- ggml_element_size(kv_self.v)*n_ctx,
5647
- ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
5648
- ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
5649
- */
5650
- struct ggml_tensor * V =
5651
- ggml_view_3d (ctx0, kv_self.v_l [il],
5652
- n_kv, n_embd_head, n_head_kv,
5653
- ggml_element_size (kv_self.v_l [il])*n_ctx,
5654
- ggml_element_size (kv_self.v_l [il])*n_ctx*n_embd_head,
5655
- 0 );
5656
- cb (V, " V" , il);
5657
-
5658
- // struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
5659
- // we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att
5660
- struct ggml_tensor * V_repeated = ggml_new_tensor_3d (ctx0, GGML_TYPE_F32, V->ne [0 ], V->ne [1 ], Q->ne [2 ]);
5661
- cb (V_repeated, " V_repeated" , il);
5662
- struct ggml_tensor * KQV = ggml_mul_mat (ctx0, ggml_repeat (ctx0, V, V_repeated), KQ_soft_max);
5663
- cb (KQV, " KQV" , il);
5664
-
5665
- // KQV_merged = KQV.permute(0, 2, 1, 3)
5666
- struct ggml_tensor * KQV_merged = ggml_permute (ctx0, KQV, 0 , 2 , 1 , 3 );
5667
- cb (KQV_merged, " KQV_merged" , il);
5668
-
5669
- // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
5670
- cur = ggml_cont_2d (ctx0, KQV_merged, n_embd, n_tokens);
5671
- cb (cur, " KQV_merged_contiguous" , il);
5672
-
5673
- // projection (no bias)
5674
- cur = ggml_mul_mat (ctx0,
5634
+ struct ggml_tensor * cur = ggml_cont_2d (ctx, kqv_merged, n_embd, n_tokens);
5635
+ cb (cur, " kqv_merged_cont" , il);
5636
+
5637
+ cur = ggml_mul_mat (ctx, wo, cur);
5638
+ return cur;
5639
+ };
5640
+
5641
+ cur = plamo_llm_build_kqv (ctx0, hparams, kv_self,
5675
5642
model.layers [il].wo ,
5676
- cur );
5677
- cb (cur, " result_wo " , il);
5643
+ Qcur, KQ_mask, n_ctx, n_tokens, n_kv, cb, il );
5644
+ cb (cur, " kqv_out " , il);
5678
5645
}
5679
5646
struct ggml_tensor * sa_out = cur;
5680
5647
0 commit comments