DeepSeek V2/V3 with -mla option (final) #12772

Closed
Changes from all commits (27 commits)
b4c169f
Initial commit with all but the MLA graph code done
jukofyork Apr 2, 2025
10207b4
Fixes
jukofyork Apr 2, 2025
ea3c05b
Just make `uint32_t n_embd_k` and `uint32_t n_embd_v`
jukofyork Apr 2, 2025
1f604a7
First working version
jukofyork Apr 2, 2025
1de077b
Fixed return bug in `DeepseekV2Model`
jukofyork Apr 2, 2025
7f92e7b
Minor fixes
jukofyork Apr 2, 2025
319e3ef
More fixes
jukofyork Apr 2, 2025
ee4b389
Renamed `wv_b` to `wv_decompress` to avoid confusion with `_b` biases
jukofyork Apr 2, 2025
c00cd9e
Better `_compressed` variable names
jukofyork Apr 2, 2025
55ad3a7
Minor comment and variable name fixes
jukofyork Apr 2, 2025
0c86f56
Moved `build_attn_mla` to better location
jukofyork Apr 2, 2025
b0c8a43
Removed `gguf.MODEL_TENSOR.ATTN_K_B` from `prepare_tensors()` for now
jukofyork Apr 2, 2025
8c329bc
Bumped `wkv_b` and `wk_b` to use F32.
jukofyork Apr 2, 2025
68302ee
Use `ggml_mul_mat_set_prec` `GGML_PREC_F32` by default for now
jukofyork Apr 2, 2025
937a48d
Better/shorter variable names and more tidying up of code
jukofyork Apr 2, 2025
1fd0aab
Fixed `kv_cmpr_pe` name
jukofyork Apr 2, 2025
4fb439f
Added `n_embd_head_k` as constant
jukofyork Apr 2, 2025
f9a0ef4
Fixed to use `build_attn_mha()` now
jukofyork Apr 3, 2025
5fe402a
`mla_attn` on then not `flash_attn` so we can run `-fa` for draft models
jukofyork Apr 3, 2025
9b862f9
"flash_attn is not compatible with mla_attn" --> flash_attn off
jukofyork Apr 3, 2025
8e23e0d
Fixed subtle bug caused by `-mla` for speculative models
jukofyork Apr 3, 2025
b384086
Removed need for `v_b_proj` storing. Tidied all ggml_row_size for quants
jukofyork Apr 4, 2025
5dbf99c
Removed both calls to `ggml_mul_mat_set_prec` for MLA and non-MLA cases
jukofyork Apr 4, 2025
01a612a
Version that has `wk_b_trans` and `wv_b` forced to use `F32`
jukofyork Apr 5, 2025
c0ffed5
Merge branch 'ggml-org:master' into mainline-llama-cpp-master--mla--f32
jukofyork Apr 5, 2025
8d12c42
Reverted to use `wk_b_trans` and `wv_b`. Documented `build_attn_mla`
jukofyork Apr 5, 2025
997a4a8
Split into k_b_proj, k_b_proj_trans and v_b_proj
jukofyork Apr 5, 2025
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -1346,6 +1346,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.flash_attn = true;
}
).set_env("LLAMA_ARG_FLASH_ATTN"));
add_opt(common_arg(
{"-mla", "--mla-attn"},
string_format("enable Multi-head Latent Attention (default: %s)", params.mla_attn ? "enabled" : "disabled"),
[](common_params & params) {
params.mla_attn = true;
}
).set_env("LLAMA_ARG_MLA_ATTN"));
add_opt(common_arg(
{"-p", "--prompt"}, "PROMPT",
"prompt to start generation with; for system message, use -sys",
1 change: 1 addition & 0 deletions common/common.cpp
@@ -1098,6 +1098,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.mla_attn = params.mla_attn;
cparams.no_perf = params.no_perf;

if (params.reranking) {
1 change: 1 addition & 0 deletions common/common.h
@@ -319,6 +319,7 @@ struct common_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
bool mla_attn = false; // MLA attention for deepseek2
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on inifinite text generation

24 changes: 24 additions & 0 deletions convert_hf_to_gguf.py
@@ -4413,6 +4413,30 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
else:
return []

if name.endswith("kv_b_proj.weight"):
name_kb = name.replace("kv_b_proj", "k_b_proj")
name_kb_trans = name.replace("kv_b_proj", "k_b_proj_trans")
name_vb = name.replace("kv_b_proj", "v_b_proj")

n_head_kv = self.hparams["num_key_value_heads"]
v_head_dim = self.hparams["v_head_dim"]
qk_nope_head_dim = self.hparams["qk_nope_head_dim"]

assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)

kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b_trans = k_b.transpose(1, 2)
k_b = k_b.reshape(n_head_kv * qk_nope_head_dim, data_torch.shape[-1])
k_b_trans = k_b_trans.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])

return [
(self.map_tensor_name(name_kb), k_b),
(self.map_tensor_name(name_kb_trans), k_b_trans),
(self.map_tensor_name(name_vb), v_b)
]

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
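To make the `kv_b_proj` split in `modify_tensors` above easier to follow, here is a minimal, self-contained sketch of the same reshaping with made-up toy dimensions (the real DeepSeek hyper-parameters are much larger); it is illustrative only and not part of the PR:

```python
# Sketch of the kv_b_proj -> (k_b, k_b_trans, v_b) split with toy dimensions.
import torch

n_head_kv        = 4    # toy stand-in for num_key_value_heads
qk_nope_head_dim = 8    # toy stand-in for qk_nope_head_dim
v_head_dim       = 6    # toy stand-in for v_head_dim
kv_lora_rank     = 16   # toy stand-in for the compressed KV width

# kv_b_proj.weight has shape (n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
data_torch = torch.randn(n_head_kv * (v_head_dim + qk_nope_head_dim), kv_lora_rank)

kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)

k_b_trans = k_b.transpose(1, 2)                                             # per-head transpose
k_b       = k_b.reshape(n_head_kv * qk_nope_head_dim, kv_lora_rank)         # -> attn_k_b
k_b_trans = k_b_trans.reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)   # -> attn_k_b_trans
v_b       = v_b.reshape(n_head_kv * v_head_dim, kv_lora_rank)               # -> attn_v_b

print(k_b.shape, k_b_trans.shape, v_b.shape)
# torch.Size([32, 16]) torch.Size([64, 8]) torch.Size([24, 16])
```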
1 change: 1 addition & 0 deletions examples/server/README.md
@@ -46,6 +46,7 @@ The project is under active development, and we are [looking for feedback and co
| `-ub, --ubatch-size N` | physical maximum batch size (default: 512)<br/>(env: LLAMA_ARG_UBATCH) |
| `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
| `-fa, --flash-attn` | enable Flash Attention (default: disabled)<br/>(env: LLAMA_ARG_FLASH_ATTN) |
| `-mla, --mla-attn` | enable Multi-head Latent Attention (default: disabled)<br/>(env: LLAMA_ARG_MLA_ATTN) |
| `--no-perf` | disable internal libllama performance timings (default: false)<br/>(env: LLAMA_ARG_NO_PERF) |
| `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) |
| `--no-escape` | do not process escape sequences |
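Since the table above documents both the flag and its environment variable, a small launch sketch follows; the binary and model paths are placeholders, and the env-var value handling is assumed to match the other boolean `LLAMA_ARG_*` flags:

```python
# Sketch: start llama-server with MLA attention enabled via the documented
# LLAMA_ARG_MLA_ATTN env var (intended to be equivalent to passing -mla / --mla-attn).
# Paths are placeholders for illustration.
import os
import subprocess

env = dict(os.environ, LLAMA_ARG_MLA_ATTN="1")
subprocess.run(["./llama-server", "-m", "deepseek-v2.gguf"], env=env, check=True)
```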
9 changes: 9 additions & 0 deletions gguf-py/gguf/constants.py
@@ -377,6 +377,9 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_B = auto()
ATTN_KV_A_MQA = auto()
ATTN_KV_B = auto()
ATTN_K_B = auto()
ATTN_K_B_TRANS = auto()
ATTN_V_B = auto()
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()
FFN_SUB_NORM = auto()
@@ -581,6 +584,9 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b",
MODEL_TENSOR.ATTN_K_B_TRANS: "blk.{bid}.attn_k_b_trans",
MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
@@ -1451,6 +1457,9 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_K_B,
MODEL_TENSOR.ATTN_K_B_TRANS,
MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
12 changes: 12 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -656,6 +656,12 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_K_B: (
"model.layers.{bid}.self_attn.k_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_K_B_TRANS: (
"model.layers.{bid}.self_attn.k_b_proj_trans", # deepseek2 (MLA specific)
),

MODEL_TENSOR.ATTN_V_B: (
"model.layers.{bid}.self_attn.v_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_Q_A_NORM: (
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
),
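As a quick post-conversion sanity check, the MLA-specific tensors registered above should appear in the output GGUF under their `blk.{bid}.attn_k_b`, `attn_k_b_trans` and `attn_v_b` names. A minimal sketch using the `gguf` package's reader (the file name is a placeholder):

```python
# Sketch: confirm a converted GGUF contains the new MLA tensors.
from gguf import GGUFReader

reader = GGUFReader("deepseek-v2-mla.gguf")  # placeholder output of convert_hf_to_gguf.py

mla_suffixes = ("attn_k_b.weight", "attn_k_b_trans.weight", "attn_v_b.weight")
for tensor in reader.tensors:
    if tensor.name.endswith(mla_suffixes):
        print(tensor.name, tensor.shape)
```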
1 change: 1 addition & 0 deletions include/llama.h
@@ -355,6 +355,7 @@ extern "C" {
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool mla_attn; // MLA attention for deepseek2
bool no_perf; // whether to measure performance timings

// Abort callback
23 changes: 6 additions & 17 deletions src/llama-arch.cpp
@@ -1030,6 +1030,9 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
{ LLM_TENSOR_ATTN_K_B_TRANS, "blk.%d.attn_k_b_trans" },
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@@ -1471,23 +1474,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K_B_TRANS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
3 changes: 3 additions & 0 deletions src/llama-arch.h
@@ -299,6 +299,9 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_K_B,
LLM_TENSOR_ATTN_K_B_TRANS,
LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
16 changes: 16 additions & 0 deletions src/llama-context.cpp
@@ -37,6 +37,7 @@ llama_context::llama_context(
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
cparams.flash_attn = params.flash_attn;
cparams.mla_attn = params.mla_attn;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;
@@ -104,6 +105,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);

@@ -2243,6 +2245,7 @@ llama_context_params llama_context_default_params() {
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
/*.mla_attn =*/ false,
/*.no_perf =*/ true,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
@@ -2274,6 +2277,19 @@ llama_context * llama_init_from_model(
params.flash_attn = false;
}

if (params.mla_attn) {
if (model->arch != LLM_ARCH_DEEPSEEK2) {
LLAMA_LOG_WARN("%s: mla_attn is only compatible with Deepseek2 - forcing off\n", __func__);
params.mla_attn = false;
} else if (model->layers[0].wk_b_trans == nullptr || model->layers[0].wv_b == nullptr) {
LLAMA_LOG_WARN("%s: mla_attn requires a gguf with the new 'attn_k_b_trans' and 'attn_v_b' tensors - forcing off\n", __func__);
params.mla_attn = false;
} else if (params.flash_attn) {
LLAMA_LOG_WARN("%s: flash_attn is not compatible with mla_attn - forcing off\n", __func__);
params.flash_attn = false;
}
}

if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
return nullptr;
1 change: 1 addition & 0 deletions src/llama-cparams.h
@@ -28,6 +28,7 @@ struct llama_cparams {
bool causal_attn;
bool offload_kqv;
bool flash_attn;
bool mla_attn;
bool no_perf;
bool warmup;
