Commit b4c169f (1 parent: 833e2b7)

Initial commit with all but the MLA graph code done

15 files changed: +201 -24 lines

common/arg.cpp
Lines changed: 7 additions & 0 deletions

@@ -1346,6 +1346,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.flash_attn = true;
         }
     ).set_env("LLAMA_ARG_FLASH_ATTN"));
+    add_opt(common_arg(
+        {"-mla", "--mla-attn"},
+        string_format("enable Multi-head Latent Attention (default: %s)", params.mla_attn ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.mla_attn = true;
+        }
+    ).set_env("LLAMA_ARG_MLA_ATTN"));
     add_opt(common_arg(
         {"-p", "--prompt"}, "PROMPT",
         "prompt to start generation with; for system message, use -sys",

common/common.cpp
Lines changed: 1 addition & 0 deletions

@@ -1098,6 +1098,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv = !params.no_kv_offload;
     cparams.flash_attn = params.flash_attn;
+    cparams.mla_attn = params.mla_attn;
     cparams.no_perf = params.no_perf;

     if (params.reranking) {

common/common.h
Lines changed: 1 addition & 0 deletions

@@ -319,6 +319,7 @@ struct common_params {
     bool simple_io = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool flash_attn = false; // flash attention
+    bool mla_attn = false; // MLA attention for deepseek2
     bool no_perf = false; // disable performance metrics
     bool ctx_shift = true; // context shift on inifinite text generation

convert_hf_to_gguf.py
Lines changed: 22 additions & 0 deletions

@@ -330,6 +330,7 @@ def prepare_tensors(self):
                             gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
                             gguf.MODEL_TENSOR.POSNET_NORM1,
                             gguf.MODEL_TENSOR.POSNET_NORM2,
+                            gguf.MODEL_TENSOR.ATTN_K_B,
                         )
                     )
                     or not new_name.endswith(".weight")

@@ -4414,6 +4415,27 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             return []

         return [(self.map_tensor_name(name), data_torch)]
+        if name.endswith("kv_b_proj.weight"):
+            name_kb = name.replace("kv_b_proj", "k_b_proj")
+            name_vb = name.replace("kv_b_proj", "v_b_proj")
+
+            n_head_kv = self.hparams["num_key_value_heads"]
+            v_head_dim = self.hparams["v_head_dim"]
+            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+
+            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+
+            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
+            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
+            k_b = k_b.transpose(1, 2)
+            k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
+            v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])
+
+            return [
+                (self.map_tensor_name(name), data_torch),
+                (self.map_tensor_name(name_kb), k_b),
+                (self.map_tensor_name(name_vb), v_b)
+            ]

     def prepare_tensors(self):
         super().prepare_tensors()
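The `kv_b_proj` handling above is the heart of the conversion change: the fused `kv_b_proj.weight` is split into a per-head `k_b` block (transposed) and a `v_b` block, written out as the new `attn_k_b`/`attn_v_b` tensors alongside the original. A minimal standalone sketch of the same reshaping, using illustrative DeepSeek-V2-Lite-style dimensions (assumed here; the converter reads the real values from the model's hparams):

    import torch

    # illustrative dimensions, assumed (roughly DeepSeek-V2-Lite); the converter reads
    # the real values from num_key_value_heads / v_head_dim / qk_nope_head_dim
    n_head_kv        = 16
    qk_nope_head_dim = 128
    v_head_dim       = 128
    kv_lora_rank     = 512   # width of the compressed KV latent (data_torch.shape[-1] above)

    # fused kv_b_proj.weight: one (nope-K + V) block per head, projected from the latent
    kv_b_proj = torch.randn(n_head_kv * (v_head_dim + qk_nope_head_dim), kv_lora_rank)

    # same steps as the diff above
    kv_b = kv_b_proj.view(n_head_kv, v_head_dim + qk_nope_head_dim, kv_lora_rank)
    k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
    k_b = k_b.transpose(1, 2)                                    # (n_head_kv, kv_lora_rank, qk_nope_head_dim)
    k_b = k_b.reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
    v_b = v_b.reshape(n_head_kv * v_head_dim, kv_lora_rank)

    print(k_b.shape)  # torch.Size([8192, 128])
    print(v_b.shape)  # torch.Size([2048, 512])

The transpose is presumably what lets `attn_k_b` be applied per head directly against the cached latent at inference time; the fused `attn_kv_b` is still emitted unchanged, so the non-MLA path should keep working from the same GGUF.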

examples/server/README.md
Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@ The project is under active development, and we are [looking for feedback and co
 | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)<br/>(env: LLAMA_ARG_UBATCH) |
 | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
 | `-fa, --flash-attn` | enable Flash Attention (default: disabled)<br/>(env: LLAMA_ARG_FLASH_ATTN) |
+| `-mla, --mla-attn` | enable Multi-head Latent Attention (default: disabled)<br/>(env: LLAMA_ARG_MLA_ATTN) |
 | `--no-perf` | disable internal libllama performance timings (default: false)<br/>(env: LLAMA_ARG_NO_PERF) |
 | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) |
 | `--no-escape` | do not process escape sequences |

gguf-py/gguf/constants.py
Lines changed: 6 additions & 0 deletions

@@ -377,6 +377,8 @@ class MODEL_TENSOR(IntEnum):
     ATTN_Q_B = auto()
     ATTN_KV_A_MQA = auto()
     ATTN_KV_B = auto()
+    ATTN_K_B = auto()
+    ATTN_V_B = auto()
     ATTN_Q_A_NORM = auto()
     ATTN_KV_A_NORM = auto()
     FFN_SUB_NORM = auto()

@@ -581,6 +583,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
     MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
     MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
+    MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b",
+    MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
     MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
     MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
     MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",

@@ -1451,6 +1455,8 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_Q_B,
         MODEL_TENSOR.ATTN_KV_A_MQA,
         MODEL_TENSOR.ATTN_KV_B,
+        MODEL_TENSOR.ATTN_K_B,
+        MODEL_TENSOR.ATTN_V_B,
         MODEL_TENSOR.ATTN_Q_A_NORM,
         MODEL_TENSOR.ATTN_KV_A_NORM,
         MODEL_TENSOR.ATTN_OUT,

gguf-py/gguf/tensor_mapping.py
Lines changed: 8 additions & 0 deletions

@@ -656,6 +656,14 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
         ),

+        MODEL_TENSOR.ATTN_K_B: (
+            "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 (MLA specific)
+        ),
+
+        MODEL_TENSOR.ATTN_V_B: (
+            "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 (MLA specific)
+        ),
+
         MODEL_TENSOR.ATTN_Q_A_NORM: (
             "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
         ),

include/llama.h
Lines changed: 1 addition & 0 deletions

@@ -355,6 +355,7 @@ extern "C" {
         bool embeddings; // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
+        bool mla_attn; // MLA attention for deepseek2
         bool no_perf; // whether to measure performance timings

         // Abort callback

src/llama-arch.cpp
Lines changed: 4 additions & 17 deletions

@@ -1030,6 +1030,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
             { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
             { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
+            { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
+            { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
             { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
             { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },

@@ -1471,23 +1473,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},

src/llama-arch.h
Lines changed: 2 additions & 0 deletions

@@ -299,6 +299,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
     LLM_TENSOR_ATTN_KV_B,
+    LLM_TENSOR_ATTN_K_B,
+    LLM_TENSOR_ATTN_V_B,
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,

src/llama-context.cpp
Lines changed: 8 additions & 0 deletions

@@ -37,6 +37,7 @@ llama_context::llama_context(
     cparams.embeddings = params.embeddings;
     cparams.offload_kqv = params.offload_kqv;
     cparams.flash_attn = params.flash_attn;
+    cparams.mla_attn = params.mla_attn;
     cparams.no_perf = params.no_perf;
     cparams.pooling_type = params.pooling_type;
     cparams.warmup = false;

@@ -104,6 +105,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);

@@ -2243,6 +2245,7 @@ llama_context_params llama_context_default_params() {
         /*.embeddings =*/ false,
         /*.offload_kqv =*/ true,
         /*.flash_attn =*/ false,
+        /*.mla_attn =*/ false,
         /*.no_perf =*/ true,
         /*.abort_callback =*/ nullptr,
         /*.abort_callback_data =*/ nullptr,

@@ -2274,6 +2277,11 @@ llama_context * llama_init_from_model(
         params.flash_attn = false;
     }

+    if (params.mla_attn && model->arch != LLM_ARCH_DEEPSEEK2) {
+        LLAMA_LOG_WARN("%s: mla_attn is only compatible with Deepseek2 - forcing off\n", __func__);
+        params.mla_attn = false;
+    }
+
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;

src/llama-cparams.h
Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ struct llama_cparams {
     bool causal_attn;
     bool offload_kqv;
     bool flash_attn;
+    bool mla_attn;
     bool no_perf;
     bool warmup;

src/llama-kv-cache.cpp
Lines changed: 16 additions & 7 deletions

@@ -27,7 +27,7 @@ bool llama_kv_cache_unified::init(

     recurrent = llama_model_is_recurrent(&model);
     v_trans = !recurrent && !cparams.flash_attn;
-    can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+    can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // TODO: support DEEPSEEK2 context shifting

     LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
             __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);

@@ -71,8 +71,17 @@ bool llama_kv_cache_unified::init(
     v_l.reserve(n_layer);

     for (int i = 0; i < n_layer; i++) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        int64_t n_embd_k;
+        int64_t n_embd_v;
+
+        // note: deepseek with MLA option converts into MQA (ie: GQA with 1 group)
+        if (cparams.mla_attn) {
+            n_embd_k = hparams.n_lora_kv + hparams.n_rot;
+            n_embd_v = hparams.n_lora_kv;
+        } else {
+            n_embd_k = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
+            n_embd_v = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        }

         const char * dev_name = "CPU";

@@ -86,17 +95,17 @@ bool llama_kv_cache_unified::init(
             buft = ggml_backend_cpu_buffer_type();
         }

-        LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
-                i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
+        LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k = %d, n_embd_v = %d, dev = %s\n", __func__,
+                i, n_embd_k, n_embd_v, dev_name);

         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
             LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
             return false;
         }

-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k*kv_size);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v*kv_size);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l.push_back(k);
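The sizing branch above is the memory-side payoff of MLA: with `--mla-attn` each layer caches a single MQA-style latent stream of `n_lora_kv + n_rot` elements for K and `n_lora_kv` for V per token, instead of full per-head K/V rows. A rough back-of-the-envelope comparison (a sketch only, assuming DeepSeek-V2-Lite-style hyperparameters rather than values read from a real GGUF):

    # assumed DeepSeek-V2-Lite-style hyperparameters, for illustration only
    n_head_kv        = 16
    qk_nope_head_dim = 128
    n_rot            = 64    # qk_rope_head_dim
    v_head_dim       = 128
    n_lora_kv        = 512   # kv_lora_rank

    # without --mla-attn: full per-head K (nope + rope) and V rows, roughly what
    # the n_embd_k_gqa / n_embd_v_gqa path works out to for this head count
    n_embd_k_gqa = n_head_kv * (qk_nope_head_dim + n_rot)   # 3072
    n_embd_v_gqa = n_head_kv * v_head_dim                   # 2048

    # with --mla-attn: one compressed latent stream per layer
    n_embd_k_mla = n_lora_kv + n_rot                        # 576
    n_embd_v_mla = n_lora_kv                                # 512

    per_token_std = n_embd_k_gqa + n_embd_v_gqa             # 5120 elements / layer / token
    per_token_mla = n_embd_k_mla + n_embd_v_mla             # 1088 elements / layer / token
    print(per_token_std / per_token_mla)                    # ~4.7x smaller KV cache per layer

For the larger DeepSeek models, which use many more attention heads, the ratio should be substantially bigger still.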
