diff --git a/common/arg.cpp b/common/arg.cpp index fa22e86cd14e6..238db672dd1ec 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1346,6 +1346,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.flash_attn = true; } ).set_env("LLAMA_ARG_FLASH_ATTN")); + add_opt(common_arg( + {"-mla", "--mla-attn"}, + string_format("enable Multi-head Latent Attention (default: %s)", params.mla_attn ? "enabled" : "disabled"), + [](common_params & params) { + params.mla_attn = true; + } + ).set_env("LLAMA_ARG_MLA_ATTN")); add_opt(common_arg( {"-p", "--prompt"}, "PROMPT", "prompt to start generation with; for system message, use -sys", diff --git a/common/common.cpp b/common/common.cpp index d4882c5123cce..d3ab321487637 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1098,6 +1098,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; + cparams.mla_attn = params.mla_attn; cparams.no_perf = params.no_perf; if (params.reranking) { diff --git a/common/common.h b/common/common.h index 725b5123d24f9..cd38a646dcfad 100644 --- a/common/common.h +++ b/common/common.h @@ -319,6 +319,7 @@ struct common_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention + bool mla_attn = false; // MLA attention for deepseek2 bool no_perf = false; // disable performance metrics bool ctx_shift = true; // context shift on inifinite text generation diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index cfe94deaf76ef..daff8f87d6e77 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4413,6 +4413,25 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter else: return [] + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj_trans") + + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.hparams["v_head_dim"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + + kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b = kv_b[:, :qk_nope_head_dim, :] + k_b_trans = k_b.transpose(1, 2) + k_b_trans = k_b_trans.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim) + + return [ + (self.map_tensor_name(name), data_torch), + (self.map_tensor_name(name_kb), k_b_trans), + ] + return [(self.map_tensor_name(name), data_torch)] def prepare_tensors(self): diff --git a/examples/server/README.md b/examples/server/README.md index a2a0903261e31..043c725d8d548 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -46,6 +46,7 @@ The project is under active development, and we are [looking for feedback and co | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)
(env: LLAMA_ARG_UBATCH) |
 | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
 | `-fa, --flash-attn` | enable Flash Attention (default: disabled)<br/>(env: LLAMA_ARG_FLASH_ATTN) |
+| `-mla, --mla-attn` | enable Multi-head Latent Attention (default: disabled)<br/>(env: LLAMA_ARG_MLA_ATTN) |
 | `--no-perf` | disable internal libllama performance timings (default: false)
(env: LLAMA_ARG_NO_PERF) | | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) | | `--no-escape` | do not process escape sequences | diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 3a52cfd1e39ac..a422f4ea326ec 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -377,6 +377,7 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_B = auto() ATTN_KV_A_MQA = auto() ATTN_KV_B = auto() + ATTN_K_B_TRANS = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() @@ -581,6 +582,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_K_B_TRANS: "blk.{bid}.attn_k_b_trans", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", @@ -1451,6 +1453,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B_TRANS, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, MODEL_TENSOR.ATTN_OUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 50bef12e3dbe7..0f95e73761c97 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -656,6 +656,10 @@ class TensorNameMap: "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 ), + MODEL_TENSOR.ATTN_K_B_TRANS: ( + "model.layers.{bid}.self_attn.k_b_proj_trans", # deepseek2 (MLA specific) + ), + MODEL_TENSOR.ATTN_Q_A_NORM: ( "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 ), diff --git a/include/llama.h b/include/llama.h index fca2b034ba270..2627efb0bd96a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -355,6 +355,7 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] + bool mla_attn; // MLA attention for deepseek2 bool no_perf; // whether to measure performance timings // Abort callback diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 047782e7d0fc8..c4b052ce80b5e 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1030,6 +1030,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B_TRANS, "blk.%d.attn_k_b_trans" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, @@ -1471,23 +1472,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - 
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B_TRANS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 297cfa4dae571..523c1457d7c73 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -299,6 +299,7 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B_TRANS, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4735e98ea040f..f4abb665d5066 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -37,6 +37,7 @@ llama_context::llama_context( cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; + cparams.mla_attn = params.mla_attn; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; cparams.warmup = false; @@ -104,6 +105,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -2243,6 +2245,7 @@ llama_context_params llama_context_default_params() { /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.flash_attn =*/ false, + /*.mla_attn =*/ false, /*.no_perf =*/ true, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, @@ -2274,6 +2277,19 @@ llama_context * llama_init_from_model( params.flash_attn = false; } + if (params.mla_attn) { + if (model->arch != LLM_ARCH_DEEPSEEK2) { + LLAMA_LOG_WARN("%s: mla_attn is only compatible with Deepseek2 - forcing off\n", __func__); + params.mla_attn = false; + } else if (model->layers[0].wk_b_trans == nullptr) { + LLAMA_LOG_WARN("%s: mla_attn requires a gguf with the new 'attn_k_b_trans' tensor - forcing off\n", __func__); + params.mla_attn = false; + } else if (params.flash_attn) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with mla_attn - forcing off\n", __func__); + params.flash_attn = false; + } + } + if (ggml_is_quantized(params.type_v) && !params.flash_attn) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); return nullptr; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 30e550f023a9e..f2309bc8eef67 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -28,6 
+28,7 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool mla_attn; bool no_perf; bool warmup; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index cec203df49268..264e77d3c898c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1130,6 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * v, ggml_tensor * kq_b, ggml_tensor * kq_mask, + ggml_tensor * v_mha_proj, bool v_trans, float kq_scale) const { //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); @@ -1141,11 +1142,18 @@ ggml_tensor * llm_graph_context::build_attn_mha( //const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; - const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0]; + const auto n_embd = q->ne[0]; + const auto n_tokens = q->ne[1]; + const auto n_head = q->ne[2]; - const auto n_tokens = q->ne[1]; - const auto n_head = q->ne[2]; - const auto n_kv = k->ne[1]; + const auto n_kv = k->ne[1]; + const auto n_head_kv = k->ne[2]; + + // note: when using MLA, the final embedding size will be changed via v_mha_proj + const auto n_embd_head_v = v_mha_proj == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mha_proj->ne[1]; + + GGML_ASSERT(k->ne[0] == q->ne[0] && "K and Q embedding size mismatch"); + GGML_ASSERT(k->ne[2] == v->ne[2] && "K and V number of heads mismatch"); ggml_tensor * cur; @@ -1164,12 +1172,29 @@ ggml_tensor * llm_graph_context::build_attn_mha( cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens); } else { + + // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply + if (n_head_kv == 1) { + q = ggml_view_2d(ctx0, q, + n_embd, n_tokens*n_head, + ggml_row_size(q->type, n_embd), + 0); + } + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); // note: this op tends to require high floating point range // while for some models F16 is enough, for others it is not, so we default to F32 here ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + if (n_head_kv == 1) { + kq = ggml_view_3d(ctx0, kq, + n_kv, n_tokens, n_head, + ggml_row_size(kq->type, n_kv), + ggml_row_size(kq->type, n_kv)*n_tokens, + 0); + } + if (arch == LLM_ARCH_GROK) { // need to do the following: // multiply by attn_output_multiplyer of 0.08838834764831845 @@ -1200,6 +1225,11 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + // for deepseek MLA we need to "decompress" from MQA back to MHA + if (v_mha_proj) { + kqv = ggml_mul_mat(ctx0, v_mha_proj, kqv); + } + ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); @@ -1258,7 +1288,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); //cb(k, "v", il); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, false, kq_scale); cb(cur, "kqv_out", il); @@ -1397,7 +1427,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v, 0); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, v_trans, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1456,8 +1486,101 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); //cb(k, "v", il); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, 
false, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, false, kq_scale); + + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + } + + if (wo_b) { + //cb(cur, "kqv_wo", il); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +ggml_tensor * llm_graph_context::build_attn_mla( + llm_graph_input_attn_kv_unified * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * v_mha_proj, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto & n_ctx = cparams.n_ctx; + + const auto kv_lora_rank = hparams.n_lora_kv; + + // note: deepseek with MLA option converts into MQA with larger n_ebed (ie: GQA with 1 group) + const int64_t n_embd_k_cmpr = kv_lora_rank + hparams.n_rot; + const int64_t n_embd_v_cmpr = kv_lora_rank; + + // note: call from llm_build_deepseek2() passes as: {n_embd, n_tokens, n_head} + const auto n_tokens = q_cur->ne[1]; + + // store to KV cache + { + const auto kv_head = kv_self->head; + + GGML_ASSERT(kv_self->size == n_ctx); + + ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], + n_tokens*n_embd_k_cmpr, + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr)*kv_head); + //cb(k_cache_view, "k_cache_view", il); + // note: storing RoPE-ed version of K in the KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view)); + + v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_cmpr, n_tokens); + + ggml_tensor * v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], + n_tokens, n_embd_v_cmpr, + ( n_ctx)*ggml_element_size(kv_self->v_l[il]), + (kv_head)*ggml_element_size(kv_self->v_l[il])); + + v_cur = ggml_transpose(ctx0, v_cur); + //cb(v_cache_view, "v_cache_view", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view)); + } + + const auto & kq_mask = inp->get_kq_mask(); + + const auto n_kv = kv_self->n; + + ggml_tensor * k = ggml_view_3d(ctx0, kv_self->k_l[il], + n_embd_k_cmpr, n_kv, 1, + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr), + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr), + 0); + //cb(k, "k", il); + + struct ggml_tensor * v = ggml_view_3d(ctx0, kv_self->v_l[il], + n_kv, n_embd_v_cmpr, 1, + ggml_element_size(kv_self->v_l[il])*n_ctx, + ggml_element_size(kv_self->v_l[il])*n_ctx, + 0); + //cb(v, "v", il); + + ggml_tensor * cur = build_attn_mha(gf, q_cur, k, v, kq_b, kq_mask, v_mha_proj, true, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1625,4 +1748,3 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } - diff --git a/src/llama-graph.h b/src/llama-graph.h index bdf19ed015e35..a52fb5631e123 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -492,6 +492,7 @@ struct llm_graph_context { ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false) ggml_tensor * kq_b, ggml_tensor * kq_mask, + ggml_tensor * v_mha_proj, bool v_trans, float kq_scale) const; @@ -537,6 +538,19 @@ struct llm_graph_context { float kq_scale, int il) const; + ggml_tensor * build_attn_mla( + llm_graph_input_attn_kv_unified * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + 
ggml_tensor * v_mha_proj, + ggml_tensor * q_cur, // [n_embd_head_q, n_tokens, n_head_q] + ggml_tensor * k_cur, // [n_embd_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_tokens] + ggml_tensor * kq_b, + float kq_scale, + int il) const; + // // recurrent // diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index dbf5f1187d9e5..54e1a5d8f4e06 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -11,6 +11,8 @@ #include #include +#include + llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) { } @@ -27,7 +29,7 @@ bool llama_kv_cache_unified::init( recurrent = llama_model_is_recurrent(&model); v_trans = !recurrent && !cparams.flash_attn; - can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA + can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // TODO: support DEEPSEEK2 context shifting LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift); @@ -71,8 +73,18 @@ bool llama_kv_cache_unified::init( v_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + int64_t n_embd_k; + int64_t n_embd_v; + + // note: be sure to check model.arch or this will cause a bug if used with a non-MLA draft model! + if (cparams.mla_attn && model.arch == LLM_ARCH_DEEPSEEK2) { + // note: deepseek2 with MLA option converts into MQA (ie: GQA with 1 group) + n_embd_k = hparams.n_lora_kv + hparams.n_rot; + n_embd_v = hparams.n_lora_kv; + } else { + n_embd_k = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + n_embd_v = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + } const char * dev_name = "CPU"; @@ -86,8 +98,8 @@ bool llama_kv_cache_unified::init( buft = ggml_backend_cpu_buffer_type(); } - LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__, - i, n_embd_k_gqa, n_embd_v_gqa, dev_name); + LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k = %" PRId64 ", n_embd_v = %" PRId64 ", dev = %s\n", __func__, + i, n_embd_k, n_embd_v, dev_name); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { @@ -95,8 +107,8 @@ bool llama_kv_cache_unified::init( return false; } - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); k_l.push_back(k); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ca6e3ab2caeb1..a6f7faf103298 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3070,9 +3070,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); } - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + 
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wk_b_trans = create_tensor(tn(LLM_TENSOR_ATTN_K_B_TRANS, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -9509,6 +9510,8 @@ struct llm_build_deepseek2 : public llm_graph_context { const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); + const uint32_t n_embd_head_k = hparams.n_embd_head_k; + const uint32_t n_embd_head_v = hparams.n_embd_head_v; const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; @@ -9537,16 +9540,14 @@ struct llm_build_deepseek2 : public llm_graph_context { { ggml_tensor * q = NULL; if (!is_lite) { - // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); cb(q, "q", il); q = build_norm(q, - model.layers[il].attn_q_a_norm, NULL, + model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il); cb(q, "q", il); - // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); cb(q, "q", il); } else { @@ -9554,96 +9555,150 @@ struct llm_build_deepseek2 : public llm_graph_context { cb(q, "q", il); } - // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, 0); cb(q_nope, "q_nope", il); - // and {n_head * n_embd_head_qk_rope, n_tokens} - ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + // and {n_embd_head_qk_rope, n_head, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, + n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope)); cb(q_pe, "q_pe", il); - // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} - ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); - cb(kv_pe_compresseed, "kv_pe_compresseed", il); + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); // split into {kv_lora_rank, n_tokens} - ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, - kv_pe_compresseed->nb[1], + ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, + kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + 
n_embd_head_qk_rope), 0); - cb(kv_compressed, "kv_compressed", il); + cb(kv_cmpr, "kv_cmpr", il); + + // and {n_embd_head_qk_rope, 1, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, + n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); - // and {n_embd_head_qk_rope, n_tokens} - ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_compresseed->nb[1], - kv_pe_compresseed->nb[1], - ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this + q_pe = ggml_cont(ctx0, q_pe); + q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this + k_pe = ggml_cont(ctx0, k_pe); + k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); cb(k_pe, "k_pe", il); // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont - kv_compressed = ggml_cont(ctx0, kv_compressed); - kv_compressed = build_norm(kv_compressed, - model.layers[il].attn_kv_a_norm, NULL, + kv_cmpr = ggml_cont(ctx0, kv_cmpr); + kv_cmpr = build_norm(kv_cmpr, + model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il); - cb(kv_compressed, "kv_compressed", il); + cb(kv_cmpr, "kv_cmpr", il); - // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); - cb(kv, "kv", il); + if (cparams.mla_attn) { + GGML_ASSERT(model.layers[il].wk_b_trans != nullptr); // should not get here, see: llama_init_from_model() - // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - 0); - cb(k_nope, "k_nope", il); + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); - // and {n_head * n_embd_head_v, n_tokens} - ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); - cb(v_states, "v_states", il); + q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3); + cb(q_pe, "q_pe_perm", il); - v_states = ggml_cont(ctx0, v_states); - cb(v_states, "v_states", il); + k_pe = ggml_view_2d(ctx0, k_pe, + n_embd_head_qk_rope, n_tokens, + ggml_row_size(k_pe->type, n_embd_head_qk_rope), + 0); + cb(k_pe, "k_pe_view", il); - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); + ggml_tensor * wk_b_trans = ggml_view_3d(ctx0, model.layers[il].wk_b_trans, + n_embd_head_qk_nope, kv_lora_rank, n_head, + 
ggml_row_size(model.layers[il].wk_b_trans->type, n_embd_head_qk_nope), + ggml_row_size(model.layers[il].wk_b_trans->type, n_embd_head_qk_nope) * kv_lora_rank, + 0); + cb(wk_b_trans, "wk_b_trans", il); - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this - q_pe = ggml_rope_ext( - ctx0, q_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(q_pe, "q_pe", il); + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, wk_b_trans, q_nope); + cb(q_nope_absorbed, "q_nope_absorbed", il); - // shared RoPE key - k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this - k_pe = ggml_rope_ext( - ctx0, k_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(k_pe, "k_pe", il); + ggml_tensor * q_states = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); + cb(q_states, "q_states", il); - ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); - cb(q_states, "q_states", il); + ggml_tensor * k_states = ggml_concat(ctx0, kv_cmpr, k_pe, 0); + cb(k_states, "k_states", il); - ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); - cb(k_states, "k_states", il); + ggml_tensor * v_states = kv_cmpr; + cb(v_states, "v_states", il); + + // {n_embd_head_v, n_head, n_tokens} + ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wkv_b, + kv_lora_rank, n_embd_head_v, n_head, + ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank), + ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank) * (n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(model.layers[il].wkv_b->type, kv_lora_rank) * n_embd_head_qk_nope); + cb(wv_b, "wv_b", il); + + // note: deepseek2 with MLA option converts into MQA (ie: GQA with 1 group) + cur = build_attn_mla(inp_attn, gf, + model.layers[il].wo, NULL, wv_b, + q_states, k_states, v_states, nullptr, kq_scale, il); + } else { + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); + cb(kv, "kv", il); + + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, + 0); + cb(k_nope, "k_nope", il); + + // and {n_embd_head_v, n_head, n_tokens} + ggml_tensor * v_states = ggml_view_3d(ctx0, kv, + n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, + ggml_row_size(kv->type, n_embd_head_qk_nope)); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, + n_embd_head_v * n_head, n_tokens, + ggml_row_size(v_states->type, n_embd_head_v) * n_head, + 0); + cb(v_states, "v_states", il); + + ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + cb(q_states, "q_states", il); + + ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + cb(k_states, "k_states", il); + + // note: deepseek2 without MLA option converts into MHA (ie: GQA with full n_head groups) + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + q_states, k_states, v_states, nullptr, kq_scale, il); + } - cur = 
build_attn(inp_attn, gf, - model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, kq_scale, il); } if (il == n_layer - 1) { diff --git a/src/llama-model.h b/src/llama-model.h index 91e6e8725acd2..9e6439c6996b6 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -160,23 +160,24 @@ struct llama_layer { struct ggml_tensor * attn_norm_enc = nullptr; // attention - struct ggml_tensor * wq = nullptr; - struct ggml_tensor * wk = nullptr; - struct ggml_tensor * wv = nullptr; - struct ggml_tensor * wo = nullptr; - struct ggml_tensor * wqkv = nullptr; - struct ggml_tensor * wq_a = nullptr; - struct ggml_tensor * wq_b = nullptr; - struct ggml_tensor * wkv_a_mqa = nullptr; - struct ggml_tensor * wkv_b = nullptr; - struct ggml_tensor * wq_cross = nullptr; - struct ggml_tensor * wk_cross = nullptr; - struct ggml_tensor * wv_cross = nullptr; - struct ggml_tensor * wo_cross = nullptr; - struct ggml_tensor * wq_enc = nullptr; - struct ggml_tensor * wk_enc = nullptr; - struct ggml_tensor * wv_enc = nullptr; - struct ggml_tensor * wo_enc = nullptr; + struct ggml_tensor * wq = nullptr; + struct ggml_tensor * wk = nullptr; + struct ggml_tensor * wv = nullptr; + struct ggml_tensor * wo = nullptr; + struct ggml_tensor * wqkv = nullptr; + struct ggml_tensor * wq_a = nullptr; + struct ggml_tensor * wq_b = nullptr; + struct ggml_tensor * wkv_a_mqa = nullptr; + struct ggml_tensor * wkv_b = nullptr; + struct ggml_tensor * wk_b_trans = nullptr; + struct ggml_tensor * wq_cross = nullptr; + struct ggml_tensor * wk_cross = nullptr; + struct ggml_tensor * wv_cross = nullptr; + struct ggml_tensor * wo_cross = nullptr; + struct ggml_tensor * wq_enc = nullptr; + struct ggml_tensor * wk_enc = nullptr; + struct ggml_tensor * wv_enc = nullptr; + struct ggml_tensor * wo_enc = nullptr; // attention bias struct ggml_tensor * bq = nullptr;
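
For reference, the tensor surgery done in the convert_hf_to_gguf.py hunk can be reproduced in isolation. The sketch below is a standalone illustration, not the converter itself: the head count and dimensions are illustrative DeepSeek-V2-Lite-like values (the real script reads them from the HF config) and the weight is random.

    # Standalone sketch of the kv_b_proj -> k_b_proj_trans derivation.
    # Sizes are assumed, illustrative values; the converter reads them from config.json.
    import torch

    n_head_kv        = 16    # num_key_value_heads
    qk_nope_head_dim = 128   # non-RoPE part of each key head
    v_head_dim       = 128   # value head size
    kv_lora_rank     = 512   # width of the compressed KV latent

    # kv_b_proj.weight: each head contributes qk_nope_head_dim k_nope rows
    # followed by v_head_dim v rows, all acting on the kv_lora_rank latent
    kv_b = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

    # split per head and keep only the k_nope block ...
    kv_b_heads = kv_b.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
    k_b = kv_b_heads[:, :qk_nope_head_dim, :]     # (n_head_kv, qk_nope, kv_lora_rank)

    # ... and store it transposed, flattened back to 2D for GGUF
    k_b_trans = k_b.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
    assert k_b_trans.shape == (n_head_kv * kv_lora_rank, qk_nope_head_dim)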
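
The reason a pre-transposed k_b is enough is the MLA weight-absorption identity that build_attn_mla() exploits: the q_nope·k_nope scores can be computed directly against the compressed latent once wk_b is folded into the query, and the per-head value projection (the v_mha_proj argument, wv_b here) can be applied after attention. A toy numerical check of that identity, with made-up sizes and with the RoPE half and the kq scale left out for brevity:

    # Toy check: attention over decompressed K/V (old MHA path) equals attention
    # over the shared latent with absorbed weights (new MQA-style MLA path).
    # Sizes and weights are arbitrary; float64 keeps the comparison tight.
    import torch

    torch.manual_seed(0)
    torch.set_default_dtype(torch.float64)
    n_head, d_nope, r, d_v, n_tok = 4, 8, 16, 8, 5  # heads, qk_nope, kv_lora_rank, v_head_dim, tokens

    wk_b = torch.randn(n_head, d_nope, r)   # per-head rows of wkv_b that produce k_nope
    wv_b = torch.randn(n_head, d_v, r)      # per-head rows of wkv_b that produce v

    q_nope = torch.randn(n_head, n_tok, d_nope)   # queries (RoPE part omitted)
    c      = torch.randn(n_tok, r)                # compressed KV latent, one row per cached token

    # MHA path: decompress per-head K and V, then attend
    k = torch.einsum('hdr,tr->htd', wk_b, c)
    v = torch.einsum('hvr,tr->htv', wv_b, c)
    out_mha = torch.softmax(q_nope @ k.transpose(1, 2), dim=-1) @ v

    # MLA path: absorb wk_b into q, attend over the latent, project back with wv_b
    q_abs   = torch.einsum('hdr,htd->htr', wk_b, q_nope)
    attn    = torch.softmax(q_abs @ c.T, dim=-1)
    out_mla = torch.einsum('hvr,htr->htv', wv_b, attn @ c)

    assert torch.allclose(out_mha, out_mla)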
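
The cache saving itself comes from the llama-kv-cache.cpp branch: with mla_attn set, each layer stores rows of n_lora_kv + n_rot elements for K and n_lora_kv for V instead of the full per-head projections. A back-of-the-envelope comparison, again with illustrative, roughly DeepSeek-V2-Lite-sized hyperparameters rather than values read from a model:

    # Per-token, per-layer KV-cache element counts for the two branches added to
    # llama_kv_cache_unified::init(). Hyperparameter values are illustrative.
    n_head_kv     = 16    # K/V heads
    n_embd_head_k = 192   # qk_nope_head_dim (128) + qk_rope_head_dim (64)
    n_embd_head_v = 128   # v_head_dim
    n_lora_kv     = 512   # kv_lora_rank
    n_rot         = 64    # qk_rope_head_dim

    # default path: full MHA-style cache
    k_default = n_head_kv * n_embd_head_k   # 3072
    v_default = n_head_kv * n_embd_head_v   # 2048

    # mla_attn path: a single compressed "head" (MQA, i.e. GQA with 1 group)
    k_mla = n_lora_kv + n_rot               # 576
    v_mla = n_lora_kv                       # 512

    print(f"default : {k_default + v_default} elements per token per layer")
    print(f"mla_attn: {k_mla + v_mla} elements per token per layer")

For these sizes that is roughly a 4.7x smaller cache per cached token (1088 vs 5120 elements), paid for with the extra wk_b_trans and wv_b matrix multiplications in build_attn_mla(). Note that the option only takes effect for LLM_ARCH_DEEPSEEK2 GGUFs that actually contain the new attn_k_b_trans tensors; otherwise llama_init_from_model() warns and forces it off, and it is mutually exclusive with flash_attn.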
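
End to end, exercising the feature means reconverting the checkpoint (so the attn_k_b_trans tensors are written) and then passing the new flag; model names and paths below are placeholders:

    python convert_hf_to_gguf.py ./DeepSeek-V2-Lite --outfile deepseek-v2-lite.gguf
    ./build/bin/llama-server -m deepseek-v2-lite.gguf -mla

The flag is also exposed as --mla-attn and via the LLAMA_ARG_MLA_ATTN environment variable.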