@@ -1552,6 +1552,32 @@ static bool llm_load_tensors(
                         layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                     }
                 } break;
+            case LLM_ARCH_COHERE2:
+                {
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    // init output from the input tok embed
+                    model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
+                                                 llama_model_loader::TENSOR_DUPLICATED);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
+                    }
+                }
+                break;
             case LLM_ARCH_OLMO:  // adapted from LLM_ARCH_LLAMA with norm params removed
                 {
                     model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
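The COHERE2 branch above reuses the token embedding as the output head (llama_model_loader::TENSOR_DUPLICATED) and loads grouped-query attention projections, so wk/wv have shape { n_embd, n_embd_gqa } while wq/wo stay square. The following standalone sketch (not part of the patch) just prints the per-layer shapes that branch requests; the hyperparameter values are hypothetical, and only the shape relationships come from the diff (n_embd_gqa being the K/V width, n_embd_head * n_head_kv).

// Standalone sketch, not part of the patch.
#include <cstdio>

int main() {
    const int n_embd      = 4096;             // hypothetical model width
    const int n_head      = 32;               // hypothetical query heads
    const int n_head_kv   = 8;                // hypothetical K/V heads (GQA)
    const int n_ff        = 14336;            // hypothetical FFN width
    const int n_embd_head = n_embd / n_head;  // per-head dimension
    const int n_embd_gqa  = n_embd_head * n_head_kv;

    printf("wq : {%d, %d}\n", n_embd, n_embd);      // LLM_TENSOR_ATTN_Q
    printf("wk : {%d, %d}\n", n_embd, n_embd_gqa);  // LLM_TENSOR_ATTN_K (narrower under GQA)
    printf("wv : {%d, %d}\n", n_embd, n_embd_gqa);  // LLM_TENSOR_ATTN_V
    printf("wo : {%d, %d}\n", n_embd, n_embd);      // LLM_TENSOR_ATTN_OUT
    printf("ffn_gate/up: {%d, %d}  ffn_down: {%d, %d}\n", n_embd, n_ff, n_ff, n_embd);
    return 0;
}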
@@ -7633,6 +7659,137 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_cohere2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        const float f_logit_scale = hparams.f_logit_scale;
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        // cohere2 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
+
+        // sliding window switch pattern
+        const int32_t sliding_window_pattern = 4;
+
+        for (int il = 0; il < n_layer; ++il) {
+            // three layers sliding window attention (window size 4096) and ROPE
+            // fourth layer uses global attention without positional embeddings
+            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+            struct ggml_tensor * ffn_inp = cur;
+
+            // self-attention
+            {
+                // rope freq factors for 128k context
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                if (is_sliding) {
+                    Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
+                                         beta_fast, beta_slow);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                                         rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                                         attn_factor, beta_fast, beta_slow);
+                    cb(Kcur, "Kcur", il);
+                } else {
+                    // For non-sliding layers, just reshape without applying RoPE
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
+                                   KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur     = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpL    = ggml_get_rows(ctx0, inpL, inp_out_ids);
+                ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+            }
+
+            struct ggml_tensor * attn_out = cur;
+
+            // feed-forward network
+            {
+                cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
+                                    NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
+                                    cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // add together residual + FFN + self-attention
+            cur = ggml_add(ctx0, cur, inpL);
+            cur = ggml_add(ctx0, cur, attn_out);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        if (f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, f_logit_scale);
+        }
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // ref: https://allenai.org/olmo
     // based on the original build_llama() function, changes:
     //   * non-parametric layer norm
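The layer loop in build_cohere2() alternates attention types by position in the stack: with sliding_window_pattern = 4, three consecutive layers use the sliding-window (SWA) mask plus RoPE, and every fourth layer uses the global mask with no rotary embedding at all. A minimal standalone sketch of that schedule follows (not part of the patch); the 8-layer count is hypothetical, the predicate is the one from the diff.

// Standalone sketch, not part of the patch.
#include <cstdio>

int main() {
    const int n_layer                = 8;   // hypothetical layer count, for illustration only
    const int sliding_window_pattern = 4;   // same constant as in build_cohere2()

    for (int il = 0; il < n_layer; ++il) {
        // layers 0,1,2 of every group of four slide; every fourth layer is global
        const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
        printf("layer %2d: %s\n", il, is_sliding ? "SWA mask + RoPE" : "global mask, no RoPE");
    }
    return 0;
}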
@@ -10384,6 +10541,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_command_r();
             } break;
+        case LLM_ARCH_COHERE2:
+            {
+                result = llm.build_cohere2();
+            } break;
         case LLM_ARCH_DBRX:
             {
                 result = llm.build_dbrx();
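Taken together, build_cohere2() gives each block a single pre-attention norm whose output feeds both the self-attention branch and the parallel, SiLU-gated FFN (ffn_inp is the normed input, not the attention output), and the block output is the sum residual + attention + FFN; the final logits are then optionally scaled by f_logit_scale. The sketch below (not part of the patch) reduces that wiring to plain float vectors; attn_branch() and ffn_branch() are hypothetical placeholders for the real ggml sub-graphs, so only the dataflow mirrors the diff.

// Standalone sketch, not part of the patch.
#include <cmath>
#include <cstdio>
#include <vector>

using vec = std::vector<float>;

// Stands in for llm_build_norm with LLM_NORM; the learned weight is omitted for brevity.
static vec layer_norm(const vec & x) {
    float mean = 0.0f, var = 0.0f;
    for (float v : x) mean += v;
    mean /= x.size();
    for (float v : x) var += (v - mean) * (v - mean);
    var /= x.size();
    vec out(x.size());
    for (size_t i = 0; i < x.size(); ++i) out[i] = (x[i] - mean) / std::sqrt(var + 1e-5f);
    return out;
}

static vec attn_branch(const vec & x) { return x; }  // placeholder for the self-attention sub-graph
static vec ffn_branch (const vec & x) { return x; }  // placeholder for the gated FFN sub-graph

static vec cohere2_block(const vec & x) {
    const vec h = layer_norm(x);   // one pre-norm feeds both branches
    const vec a = attn_branch(h);  // self-attention output
    const vec f = ffn_branch(h);   // parallel FFN output
    vec out(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        out[i] = x[i] + a[i] + f[i];  // residual + FFN + self-attention, as in build_cohere2()
    }
    return out;
}

int main() {
    const vec x = { 1.0f, -2.0f, 0.5f, 3.0f };
    for (float v : cohere2_block(x)) printf("%.3f ", v);
    printf("\n");
    return 0;
}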