diff --git a/common/arg.cpp b/common/arg.cpp
index fa22e86cd14e6..238db672dd1ec 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1346,6 +1346,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.flash_attn = true;
}
).set_env("LLAMA_ARG_FLASH_ATTN"));
+ add_opt(common_arg(
+ {"-mla", "--mla-attn"},
+ string_format("enable Multi-head Latent Attention (default: %s)", params.mla_attn ? "enabled" : "disabled"),
+ [](common_params & params) {
+ params.mla_attn = true;
+ }
+ ).set_env("LLAMA_ARG_MLA_ATTN"));
add_opt(common_arg(
{"-p", "--prompt"}, "PROMPT",
"prompt to start generation with; for system message, use -sys",
diff --git a/common/common.cpp b/common/common.cpp
index d4882c5123cce..d3ab321487637 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1098,6 +1098,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
+ cparams.mla_attn = params.mla_attn;
cparams.no_perf = params.no_perf;
if (params.reranking) {
diff --git a/common/common.h b/common/common.h
index 725b5123d24f9..cd38a646dcfad 100644
--- a/common/common.h
+++ b/common/common.h
@@ -319,6 +319,7 @@ struct common_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
+ bool mla_attn = false; // MLA attention for deepseek2
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on infinite text generation
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index cfe94deaf76ef..daff8f87d6e77 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4413,6 +4413,25 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
else:
return []
+ if name.endswith("kv_b_proj.weight"):
+ name_kb = name.replace("kv_b_proj", "k_b_proj_trans")
+
+ n_head_kv = self.hparams["num_key_value_heads"]
+ v_head_dim = self.hparams["v_head_dim"]
+ qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+
+ assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+
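+ # kv_b_proj maps the compressed latent (kv_lora_rank) to the per-head no-RoPE K and V parts;
+ # for MLA we additionally export the no-RoPE K part transposed, so that at inference time it
+ # can be applied directly to the cached latent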
+ kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
+ k_b = kv_b[:, :qk_nope_head_dim, :]
+ k_b_trans = k_b.transpose(1, 2)
+ k_b_trans = k_b_trans.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
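+ # resulting layout: (n_head_kv * kv_lora_rank, qk_nope_head_dim), i.e. one transposed
+ # (kv_lora_rank x qk_nope_head_dim) block per head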
+
+ return [
+ (self.map_tensor_name(name), data_torch),
+ (self.map_tensor_name(name_kb), k_b_trans),
+ ]
+
return [(self.map_tensor_name(name), data_torch)]
def prepare_tensors(self):
diff --git a/examples/server/README.md b/examples/server/README.md
index a2a0903261e31..043c725d8d548 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -46,6 +46,7 @@ The project is under active development, and we are [looking for feedback and co
| `-ub, --ubatch-size N` | physical maximum batch size (default: 512)<br/>(env: LLAMA_ARG_UBATCH) |
| `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
| `-fa, --flash-attn` | enable Flash Attention (default: disabled)<br/>(env: LLAMA_ARG_FLASH_ATTN) |
+| `-mla, --mla-attn` | enable Multi-head Latent Attention (default: disabled)<br/>(env: LLAMA_ARG_MLA_ATTN) |
| `--no-perf` | disable internal libllama performance timings (default: false)<br/>(env: LLAMA_ARG_NO_PERF) |
| `-e, --escape` | process escape sequences (\n, \r, \t, \', \", \\) (default: true) |
| `--no-escape` | do not process escape sequences |
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 3a52cfd1e39ac..a422f4ea326ec 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -377,6 +377,7 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_B = auto()
ATTN_KV_A_MQA = auto()
ATTN_KV_B = auto()
+ ATTN_K_B_TRANS = auto()
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()
FFN_SUB_NORM = auto()
@@ -581,6 +582,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
+ MODEL_TENSOR.ATTN_K_B_TRANS: "blk.{bid}.attn_k_b_trans",
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
@@ -1451,6 +1453,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
+ MODEL_TENSOR.ATTN_K_B_TRANS,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 50bef12e3dbe7..0f95e73761c97 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -656,6 +656,10 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
),
+ MODEL_TENSOR.ATTN_K_B_TRANS: (
+ "model.layers.{bid}.self_attn.k_b_proj_trans", # deepseek2 (MLA specific)
+ ),
+
MODEL_TENSOR.ATTN_Q_A_NORM: (
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
),
diff --git a/include/llama.h b/include/llama.h
index fca2b034ba270..2627efb0bd96a 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -355,6 +355,7 @@ extern "C" {
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
+ bool mla_attn; // MLA attention for deepseek2
bool no_perf; // whether to measure performance timings
// Abort callback
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 047782e7d0fc8..c4b052ce80b5e 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -1030,6 +1030,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
+ { LLM_TENSOR_ATTN_K_B_TRANS, "blk.%d.attn_k_b_trans" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@@ -1471,23 +1472,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
- {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_ATTN_K_B_TRANS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 297cfa4dae571..523c1457d7c73 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -299,6 +299,7 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
+ LLM_TENSOR_ATTN_K_B_TRANS,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 4735e98ea040f..f4abb665d5066 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -37,6 +37,7 @@ llama_context::llama_context(
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
cparams.flash_attn = params.flash_attn;
+ cparams.mla_attn = params.mla_attn;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;
@@ -104,6 +105,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+ LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@@ -2243,6 +2245,7 @@ llama_context_params llama_context_default_params() {
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
+ /*.mla_attn =*/ false,
/*.no_perf =*/ true,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
@@ -2274,6 +2277,19 @@ llama_context * llama_init_from_model(
params.flash_attn = false;
}
+ if (params.mla_attn) {
+ if (model->arch != LLM_ARCH_DEEPSEEK2) {
+ LLAMA_LOG_WARN("%s: mla_attn is only compatible with Deepseek2 - forcing off\n", __func__);
+ params.mla_attn = false;
+ } else if (model->layers[0].wk_b_trans == nullptr) {
+ LLAMA_LOG_WARN("%s: mla_attn requires a gguf with the new 'attn_k_b_trans' tensor - forcing off\n", __func__);
+ params.mla_attn = false;
+ } else if (params.flash_attn) {
+ LLAMA_LOG_WARN("%s: flash_attn is not compatible with mla_attn - forcing off\n", __func__);
+ params.flash_attn = false;
+ }
+ }
+
if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
return nullptr;
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 30e550f023a9e..f2309bc8eef67 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -28,6 +28,7 @@ struct llama_cparams {
bool causal_attn;
bool offload_kqv;
bool flash_attn;
+ bool mla_attn;
bool no_perf;
bool warmup;
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index cec203df49268..264e77d3c898c 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1130,6 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
ggml_tensor * v,
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
+ ggml_tensor * v_mha_proj,
bool v_trans,
float kq_scale) const {
//const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
@@ -1141,11 +1142,18 @@ ggml_tensor * llm_graph_context::build_attn_mha(
//const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v;
- const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
+ const auto n_embd = q->ne[0];
+ const auto n_tokens = q->ne[1];
+ const auto n_head = q->ne[2];
- const auto n_tokens = q->ne[1];
- const auto n_head = q->ne[2];
- const auto n_kv = k->ne[1];
+ const auto n_kv = k->ne[1];
+ const auto n_head_kv = k->ne[2];
+
+ // note: when using MLA, the per-head value embedding size (n_embd_head_v) is taken from v_mha_proj
+ const auto n_embd_head_v = v_mha_proj == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mha_proj->ne[1];
+
+ GGML_ASSERT(k->ne[0] == q->ne[0] && "K and Q embedding size mismatch");
+ GGML_ASSERT(k->ne[2] == v->ne[2] && "K and V number of heads mismatch");
ggml_tensor * cur;
@@ -1164,12 +1172,29 @@ ggml_tensor * llm_graph_context::build_attn_mha(
cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
} else {
+
+ // for MQA (i.e. GQA with 1 group) we don't need a batched matrix multiply
+ if (n_head_kv == 1) {
+ q = ggml_view_2d(ctx0, q,
+ n_embd, n_tokens*n_head,
+ ggml_row_size(q->type, n_embd),
+ 0);
+ }
+
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
// note: this op tends to require high floating point range
// while for some models F16 is enough, for others it is not, so we default to F32 here
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
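+ // restore the per-head 3D view of KQ so that the mask, softmax and V multiplication
+ // below still operate per head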
+ if (n_head_kv == 1) {
+ kq = ggml_view_3d(ctx0, kq,
+ n_kv, n_tokens, n_head,
+ ggml_row_size(kq->type, n_kv),
+ ggml_row_size(kq->type, n_kv)*n_tokens,
+ 0);
+ }
+
if (arch == LLM_ARCH_GROK) {
// need to do the following:
// multiply by attn_output_multiplyer of 0.08838834764831845
@@ -1200,6 +1225,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+ // for deepseek MLA we need to "decompress" from MQA back to MHA
+ if (v_mha_proj) {
+ kqv = ggml_mul_mat(ctx0, v_mha_proj, kqv);
+ }
+
ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
@@ -1258,7 +1288,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
//cb(k, "v", il);
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, false, kq_scale);
cb(cur, "kqv_out", il);
@@ -1397,7 +1427,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
0);
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, v_trans, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
@@ -1456,8 +1486,101 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
//cb(k, "v", il);
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+ ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, false, kq_scale);
+
+ cb(cur, "kqv_out", il);
+
+ if (wo) {
+ cur = build_lora_mm(wo, cur);
+ }
+
+ if (wo_b) {
+ //cb(cur, "kqv_wo", il);
+ }
+
+ if (wo_b) {
+ cur = ggml_add(ctx0, cur, wo_b);
+ }
+
+ return cur;
+}
+
+ggml_tensor * llm_graph_context::build_attn_mla(
+ llm_graph_input_attn_kv_unified * inp,
+ ggml_cgraph * gf,
+ ggml_tensor * wo,
+ ggml_tensor * wo_b,
+ ggml_tensor * v_mha_proj,
+ ggml_tensor * q_cur,
+ ggml_tensor * k_cur,
+ ggml_tensor * v_cur,
+ ggml_tensor * kq_b,
+ float kq_scale,
+ int il) const {
+ // these nodes are added to the graph together so that they are not reordered
+ // by doing so, the number of splits in the graph is reduced
+ ggml_build_forward_expand(gf, q_cur);
+ ggml_build_forward_expand(gf, k_cur);
+ ggml_build_forward_expand(gf, v_cur);
+
+ const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+ const auto & n_ctx = cparams.n_ctx;
+
+ const auto kv_lora_rank = hparams.n_lora_kv;
+
+ // note: the deepseek2 MLA option turns the attention into MQA over the compressed KV (i.e. GQA with 1 group) with a larger per-head n_embd
+ const int64_t n_embd_k_cmpr = kv_lora_rank + hparams.n_rot;
+ const int64_t n_embd_v_cmpr = kv_lora_rank;
+
+ // note: the call from llm_build_deepseek2() passes q_cur as {n_embd, n_tokens, n_head}
+ const auto n_tokens = q_cur->ne[1];
+
+ // store to KV cache
+ {
+ const auto kv_head = kv_self->head;
+
+ GGML_ASSERT(kv_self->size == n_ctx);
+
+ ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il],
+ n_tokens*n_embd_k_cmpr,
+ ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr)*kv_head);
+ //cb(k_cache_view, "k_cache_view", il);
+ // note: storing RoPE-ed version of K in the KV cache
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
+
+ v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_cmpr, n_tokens);
+
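+ // note: V is stored transposed in the cache (one row per position), matching the
+ // v_trans == true call to build_attn_mha() below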
+ ggml_tensor * v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il],
+ n_tokens, n_embd_v_cmpr,
+ ( n_ctx)*ggml_element_size(kv_self->v_l[il]),
+ (kv_head)*ggml_element_size(kv_self->v_l[il]));
+
+ v_cur = ggml_transpose(ctx0, v_cur);
+ //cb(v_cache_view, "v_cache_view", il);
+
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+ }
+
+ const auto & kq_mask = inp->get_kq_mask();
+
+ const auto n_kv = kv_self->n;
+
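+ // view the whole compressed cache as a single KV head (MQA); the per-head values
+ // are recovered later via v_mha_proj inside build_attn_mha()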
+ ggml_tensor * k = ggml_view_3d(ctx0, kv_self->k_l[il],
+ n_embd_k_cmpr, n_kv, 1,
+ ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr),
+ ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr),
+ 0);
+ //cb(k, "k", il);
+
+ struct ggml_tensor * v = ggml_view_3d(ctx0, kv_self->v_l[il],
+ n_kv, n_embd_v_cmpr, 1,
+ ggml_element_size(kv_self->v_l[il])*n_ctx,
+ ggml_element_size(kv_self->v_l[il])*n_ctx,
+ 0);
+ //cb(v, "v", il);
+
+ ggml_tensor * cur = build_attn_mha(gf, q_cur, k, v, kq_b, kq_mask, v_mha_proj, true, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
@@ -1625,4 +1748,3 @@ void llm_graph_context::build_pooling(
ggml_build_forward_expand(gf, cur);
}
-
diff --git a/src/llama-graph.h b/src/llama-graph.h
index bdf19ed015e35..a52fb5631e123 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -492,6 +492,7 @@ struct llm_graph_context {
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
+ ggml_tensor * v_mha_proj,
bool v_trans,
float kq_scale) const;
@@ -537,6 +538,19 @@ struct llm_graph_context {
float kq_scale,
int il) const;
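+ // MLA variant of build_attn(): K/V are cached in compressed (latent) form and the
+ // per-head V values are recovered via v_mha_proj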
+ ggml_tensor * build_attn_mla(
+ llm_graph_input_attn_kv_unified * inp,
+ ggml_cgraph * gf,
+ ggml_tensor * wo,
+ ggml_tensor * wo_b,
+ ggml_tensor * v_mha_proj,
+ ggml_tensor * q_cur, // [n_embd_head_q, n_tokens, n_head_q]
+ ggml_tensor * k_cur, // [n_embd_head_k, n_tokens]
+ ggml_tensor * v_cur, // [n_embd_head_v, n_tokens]
+ ggml_tensor * kq_b,
+ float kq_scale,
+ int il) const;
+
//
// recurrent
//
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index dbf5f1187d9e5..54e1a5d8f4e06 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -11,6 +11,8 @@
#include