DeepSeek V2/V3 with -mla option #12725

Closed

Changes from all commits (24 commits):
b4c169f  Initial commit with all but the MLA graph code done (jukofyork, Apr 2, 2025)
10207b4  Fixes (jukofyork, Apr 2, 2025)
ea3c05b  Just make `uint32_t n_embd_k` and `uint32_t n_embd_v` (jukofyork, Apr 2, 2025)
1f604a7  First working version (jukofyork, Apr 2, 2025)
1de077b  Fixed return bug in `DeepseekV2Model` (jukofyork, Apr 2, 2025)
7f92e7b  Minor fixes (jukofyork, Apr 2, 2025)
319e3ef  More fixes (jukofyork, Apr 2, 2025)
ee4b389  Renamed `wv_b` to `wv_decompress` to avoid confusion with `_b` biases (jukofyork, Apr 2, 2025)
c00cd9e  Better `_compressed` variable names (jukofyork, Apr 2, 2025)
55ad3a7  Minor comment and variable name fixes (jukofyork, Apr 2, 2025)
0c86f56  Moved `build_attn_mla` to better location (jukofyork, Apr 2, 2025)
b0c8a43  Removed `gguf.MODEL_TENSOR.ATTN_K_B` from `prepare_tensors()` for now (jukofyork, Apr 2, 2025)
8c329bc  Bumped `wkv_b` and `wk_b` to use F32 (jukofyork, Apr 2, 2025)
68302ee  Use `ggml_mul_mat_set_prec` `GGML_PREC_F32` by default for now (jukofyork, Apr 2, 2025)
937a48d  Better/shorter variable names and more tidying up of code (jukofyork, Apr 2, 2025)
1fd0aab  Fixed `kv_cmpr_pe` name (jukofyork, Apr 2, 2025)
4fb439f  Added `n_embd_head_k` as constant (jukofyork, Apr 2, 2025)
f9a0ef4  Fixed to use `build_attn_mha()` now (jukofyork, Apr 3, 2025)
5fe402a  `mla_attn` on then not `flash_attn` so we can run `-fa` for draft models (jukofyork, Apr 3, 2025)
9b862f9  "flash_attn is not compatible with mla_attn" --> flash_attn off (jukofyork, Apr 3, 2025)
8e23e0d  Fixed subtle bug caused by `-mla` for speculative models (jukofyork, Apr 3, 2025)
b384086  Removed need for `v_b_proj` storing. Tidied all ggml_row_size for quants (jukofyork, Apr 4, 2025)
5dbf99c  Removed both calls to `ggml_mul_mat_set_prec` for MLA and non-MLA cases (jukofyork, Apr 4, 2025)
f0d514a  Merge branch 'ggml-org:master' into mainline-llama-cpp-master--mla (jukofyork, Apr 5, 2025)
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -1346,6 +1346,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.flash_attn = true;
}
).set_env("LLAMA_ARG_FLASH_ATTN"));
add_opt(common_arg(
{"-mla", "--mla-attn"},
string_format("enable Multi-head Latent Attention (default: %s)", params.mla_attn ? "enabled" : "disabled"),
[](common_params & params) {
params.mla_attn = true;
}
).set_env("LLAMA_ARG_MLA_ATTN"));
add_opt(common_arg(
{"-p", "--prompt"}, "PROMPT",
"prompt to start generation with; for system message, use -sys",
1 change: 1 addition & 0 deletions common/common.cpp
@@ -1098,6 +1098,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.mla_attn = params.mla_attn;
cparams.no_perf = params.no_perf;

if (params.reranking) {
1 change: 1 addition & 0 deletions common/common.h
@@ -319,6 +319,7 @@ struct common_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
bool mla_attn = false; // MLA attention for deepseek2
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on inifinite text generation

19 changes: 19 additions & 0 deletions convert_hf_to_gguf.py
@@ -4413,6 +4413,25 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
else:
return []

if name.endswith("kv_b_proj.weight"):
name_kb = name.replace("kv_b_proj", "k_b_proj_trans")

n_head_kv = self.hparams["num_key_value_heads"]
v_head_dim = self.hparams["v_head_dim"]
qk_nope_head_dim = self.hparams["qk_nope_head_dim"]

assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)

kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
k_b = kv_b[:, :qk_nope_head_dim, :]
k_b_trans = k_b.transpose(1, 2)
k_b_trans = k_b_trans.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)

return [
(self.map_tensor_name(name), data_torch),
(self.map_tensor_name(name_kb), k_b_trans),
]

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
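For orientation, the shape bookkeeping behind this split is sketched below. The per-head sizes quoted afterwards are the usual DeepSeek-V2/V3 config values and are assumptions for illustration, not values read from the converter.

```math
W_{kv\_b} \in \mathbb{R}^{\,h\,(d_{nope}+d_v)\times r}
\;\longrightarrow\;
\big(W^{(i)}_{k\_nope},\,W^{(i)}_{v}\big)_{i=1..h},
\qquad
W_{k\_b\_trans} = \big[(W^{(1)}_{k\_nope})^{\top};\,\dots;\,(W^{(h)}_{k\_nope})^{\top}\big] \in \mathbb{R}^{\,h\,r\times d_{nope}}
```

With the typical DeepSeek-V2/V3 values h = 128, d_nope = 128, d_v = 128 and r = kv_lora_rank = 512, `kv_b_proj` is 32768 x 512 and the new `attn_k_b_trans` tensor is 65536 x 128 per layer. At inference the per-head blocks of `attn_k_b_trans` let the non-RoPE part of the query be projected into the latent space, so attention can run directly against the compressed KV cache.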
1 change: 1 addition & 0 deletions examples/server/README.md
@@ -46,6 +46,7 @@ The project is under active development, and we are [looking for feedback and co
| `-ub, --ubatch-size N` | physical maximum batch size (default: 512)<br/>(env: LLAMA_ARG_UBATCH) |
| `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
| `-fa, --flash-attn` | enable Flash Attention (default: disabled)<br/>(env: LLAMA_ARG_FLASH_ATTN) |
| `-mla, --mla-attn` | enable Multi-head Latent Attention (default: disabled)<br/>(env: LLAMA_ARG_MLA_ATTN) |
| `--no-perf` | disable internal libllama performance timings (default: false)<br/>(env: LLAMA_ARG_NO_PERF) |
| `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) |
| `--no-escape` | do not process escape sequences |
3 changes: 3 additions & 0 deletions gguf-py/gguf/constants.py
@@ -377,6 +377,7 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_B = auto()
ATTN_KV_A_MQA = auto()
ATTN_KV_B = auto()
ATTN_K_B_TRANS = auto()
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()
FFN_SUB_NORM = auto()
@@ -581,6 +582,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
MODEL_TENSOR.ATTN_K_B_TRANS: "blk.{bid}.attn_k_b_trans",
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
@@ -1451,6 +1453,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_K_B_TRANS,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
4 changes: 4 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -656,6 +656,10 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_K_B_TRANS: (
"model.layers.{bid}.self_attn.k_b_proj_trans", # deepseek2 (MLA specific)
),

MODEL_TENSOR.ATTN_Q_A_NORM: (
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
),
1 change: 1 addition & 0 deletions include/llama.h
@@ -355,6 +355,7 @@ extern "C" {
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool mla_attn; // MLA attention for deepseek2
bool no_perf; // whether to measure performance timings

// Abort callback
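A minimal sketch of how an application could opt in to the new flag through the C API. It assumes a DeepSeek2 GGUF converted with the new `attn_k_b_trans` tensors; `llama_backend_init`, `llama_model_default_params`, `llama_model_load_from_file`, `llama_model_free` and `llama_backend_free` are the usual llama.cpp entry points and are not part of this diff.

```cpp
// sketch: enable MLA attention via llama_context_params.mla_attn
// (if the GGUF lacks attn_k_b_trans, llama_init_from_model() warns and forces it back off)
#include "llama.h"

#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        std::fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
        return 1;
    }

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file(argv[1], mparams);
    if (model == nullptr) {
        std::fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.mla_attn   = true;   // the new option added by this PR
    cparams.flash_attn = false;  // flash_attn is forced off when mla_attn is active

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx == nullptr) {
        std::fprintf(stderr, "failed to create context\n");
        return 1;
    }

    // ... tokenize and llama_decode() as usual ...

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```

The same switch is exposed on the command line as `-mla` / `--mla-attn` and via the `LLAMA_ARG_MLA_ATTN` environment variable, as shown in the `common/arg.cpp` hunk above.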
19 changes: 2 additions & 17 deletions src/llama-arch.cpp
@@ -1030,6 +1030,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_K_B_TRANS, "blk.%d.attn_k_b_trans" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@@ -1471,23 +1472,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
Review comment on lines -1474 to -1490:

Collaborator: I think these are deleted inadvertently? For example, ffn_*_shexp are still used by qwen moe

Collaborator (author): I think these were all accidentally duplicated in the main branch so I removed the duplicates when inserting the new ones.

{LLM_TENSOR_ATTN_K_B_TRANS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -299,6 +299,7 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_K_B_TRANS,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
16 changes: 16 additions & 0 deletions src/llama-context.cpp
@@ -37,6 +37,7 @@ llama_context::llama_context(
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
cparams.flash_attn = params.flash_attn;
cparams.mla_attn = params.mla_attn;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;
@@ -104,6 +105,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);

@@ -2243,6 +2245,7 @@ llama_context_params llama_context_default_params() {
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
/*.mla_attn =*/ false,
/*.no_perf =*/ true,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
@@ -2274,6 +2277,19 @@ llama_context * llama_init_from_model(
params.flash_attn = false;
}

if (params.mla_attn) {
if (model->arch != LLM_ARCH_DEEPSEEK2) {
LLAMA_LOG_WARN("%s: mla_attn is only compatible with Deepseek2 - forcing off\n", __func__);
params.mla_attn = false;
} else if (model->layers[0].wk_b_trans == nullptr) {
LLAMA_LOG_WARN("%s: mla_attn requires a gguf with the new 'attn_k_b_trans' tensor - forcing off\n", __func__);
params.mla_attn = false;
} else if (params.flash_attn) {
LLAMA_LOG_WARN("%s: flash_attn is not compatible with mla_attn - forcing off\n", __func__);
params.flash_attn = false;
}
}

if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
return nullptr;
1 change: 1 addition & 0 deletions src/llama-cparams.h
@@ -28,6 +28,7 @@ struct llama_cparams {
bool causal_attn;
bool offload_kqv;
bool flash_attn;
bool mla_attn;
bool no_perf;
bool warmup;

138 changes: 130 additions & 8 deletions src/llama-graph.cpp
@@ -1130,6 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
ggml_tensor * v,
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
ggml_tensor * v_mha_proj,
bool v_trans,
float kq_scale) const {
//const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
@@ -1141,11 +1142,18 @@ ggml_tensor * llm_graph_context::build_attn_mha(
//const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v;

const auto n_embd    = q->ne[0];
const auto n_tokens  = q->ne[1];
const auto n_head    = q->ne[2];
const auto n_kv      = k->ne[1];
const auto n_head_kv = k->ne[2];

// note: when using MLA, the final embedding size will be changed via v_mha_proj
const auto n_embd_head_v = v_mha_proj == nullptr ? (v_trans ? v->ne[1] : v->ne[0]) : v_mha_proj->ne[1];

GGML_ASSERT(k->ne[0] == q->ne[0] && "K and Q embedding size mismatch");
GGML_ASSERT(k->ne[2] == v->ne[2] && "K and V number of heads mismatch");

ggml_tensor * cur;

@@ -1164,12 +1172,29 @@

cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
} else {

// for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
if (n_head_kv == 1) {
q = ggml_view_2d(ctx0, q,
n_embd, n_tokens*n_head,
ggml_row_size(q->type, n_embd),
0);
}

ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

// note: this op tends to require high floating point range
// while for some models F16 is enough, for others it is not, so we default to F32 here
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

if (n_head_kv == 1) {
kq = ggml_view_3d(ctx0, kq,
n_kv, n_tokens, n_head,
ggml_row_size(kq->type, n_kv),
ggml_row_size(kq->type, n_kv)*n_tokens,
0);
}

if (arch == LLM_ARCH_GROK) {
// need to do the following:
// multiply by attn_output_multiplyer of 0.08838834764831845
Expand Down Expand Up @@ -1200,6 +1225,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(

ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);

// for deepseek MLA we need to "decompress" from MQA back to MHA
if (v_mha_proj) {
kqv = ggml_mul_mat(ctx0, v_mha_proj, kqv);
}

ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
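The algebra the two MQA-specific branches above rely on, written out as a sketch (notation mine, not taken verbatim from the PR): with MLA there is a single compressed K/V "head", so the per-head batched matmuls collapse into ordinary 2-D products, and the true per-head values are only recovered at the end.

```math
[\,S_1\;\cdots\;S_H\,] = K_c^{\top}\,[\,q_1\;\cdots\;q_H\,],
\qquad
\tilde{o}_h = V_c\,\operatorname{softmax}\!\big(s\,S_h + \text{mask}\big),
\qquad
o_h = W_{v\_b}^{(h)}\,\tilde{o}_h
```

Here K_c and V_c are the shared compressed key/value tensors read from the cache, q_h is the per-head query with the `attn_k_b_trans` projection already absorbed into it, s is `kq_scale`, and W_vb^(h) is the per-head block of `v_mha_proj` that decompresses the latent output back to `v_head_dim`. Viewing `q` as a 2-D `(n_embd, n_tokens*n_head)` tensor implements the first equation as one matrix product over all heads instead of `n_head` batched ones; the `v_mha_proj` matmul at the end is the third equation.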
@@ -1258,7 +1288,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
//cb(k, "v", il);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, false, kq_scale);

cb(cur, "kqv_out", il);

@@ -1397,7 +1427,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
0);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, v_trans, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
@@ -1456,8 +1486,101 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
//cb(k, "v", il);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, false, kq_scale);

cb(cur, "kqv_out", il);

if (wo) {
cur = build_lora_mm(wo, cur);
}

if (wo_b) {
//cb(cur, "kqv_wo", il);
}

if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}

return cur;
}

ggml_tensor * llm_graph_context::build_attn_mla(
llm_graph_input_attn_kv_unified * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * v_mha_proj,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);

const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const auto & n_ctx = cparams.n_ctx;

const auto kv_lora_rank = hparams.n_lora_kv;

// note: deepseek with MLA option converts into MQA with larger n_ebed (ie: GQA with 1 group)
const int64_t n_embd_k_cmpr = kv_lora_rank + hparams.n_rot;
const int64_t n_embd_v_cmpr = kv_lora_rank;

// note: call from llm_build_deepseek2() passes as: {n_embd, n_tokens, n_head}
const auto n_tokens = q_cur->ne[1];

// store to KV cache
{
const auto kv_head = kv_self->head;

GGML_ASSERT(kv_self->size == n_ctx);

ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il],
n_tokens*n_embd_k_cmpr,
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr)*kv_head);
//cb(k_cache_view, "k_cache_view", il);

// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));

v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_cmpr, n_tokens);

ggml_tensor * v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il],
n_tokens, n_embd_v_cmpr,
( n_ctx)*ggml_element_size(kv_self->v_l[il]),
(kv_head)*ggml_element_size(kv_self->v_l[il]));

v_cur = ggml_transpose(ctx0, v_cur);
//cb(v_cache_view, "v_cache_view", il);

ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
}

const auto & kq_mask = inp->get_kq_mask();

const auto n_kv = kv_self->n;

ggml_tensor * k = ggml_view_3d(ctx0, kv_self->k_l[il],
n_embd_k_cmpr, n_kv, 1,
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr),
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_cmpr),
0);
//cb(k, "k", il);

struct ggml_tensor * v = ggml_view_3d(ctx0, kv_self->v_l[il],
n_kv, n_embd_v_cmpr, 1,
ggml_element_size(kv_self->v_l[il])*n_ctx,
ggml_element_size(kv_self->v_l[il])*n_ctx,
0);
//cb(v, "v", il);

ggml_tensor * cur = build_attn_mha(gf, q_cur, k, v, kq_b, kq_mask, v_mha_proj, true, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
@@ -1625,4 +1748,3 @@ void llm_graph_context::build_pooling(

ggml_build_forward_expand(gf, cur);
}
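For a rough sense of what the compressed cache buys, here is a back-of-the-envelope per-layer, per-token comparison. The dimensions are the usual DeepSeek-V2/V3 config values (h = 128 heads, d_nope = d_v = 128, n_rot = 64, r = kv_lora_rank = 512) and are assumptions for illustration, not numbers produced by the code above.

```math
\underbrace{(r + n_{rot})}_{\text{compressed K row}} + \underbrace{r}_{\text{compressed V row}} = (512+64) + 512 = 1088
\qquad\text{vs.}\qquad
\underbrace{h\,(d_{nope}+n_{rot})}_{\text{full K}} + \underbrace{h\,d_v}_{\text{full V}} = 128\cdot192 + 128\cdot128 = 40960
```

At equal cache precision that is roughly a 37x smaller KV cache per token, which is the usual motivation for running DeepSeek-style models with MLA caching instead of the default MHA path.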
