
Commit cbe65a9

bugfix: fix MLA with new JIT pipeline (#620)
Some of the commits for fixing MLA are missing in #618; this PR adds them back.
1 parent b27a2cc commit cbe65a9

4 files changed (+42, −25 lines)


python/flashinfer/decode.py

Lines changed: 33 additions & 19 deletions
@@ -1270,6 +1270,8 @@ def plan(
             q_data_type = data_type
         q_data_type = canonicalize_torch_dtype(q_data_type)
 
+        indptr_host = indptr.to("cpu")
+
         self._cached_module = get_batch_decode_mla_module(
             q_data_type,
             data_type,
@@ -1284,7 +1286,7 @@ def plan(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
-            indptr,
+            indptr_host,
             batch_size,
             num_qo_heads,
             page_size,
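
The plan() change above builds a host-side copy of the paged-KV indptr before invoking the JIT module, since the scheduling step inspects the per-request page counts on the CPU. Below is a minimal illustration of the CSR-style indptr layout involved; the three-request example and its page counts are made up for illustration and are not taken from this diff.

import torch

# Hypothetical paged KV layout: 3 requests holding 2, 1, and 3 pages.
# indptr[i + 1] - indptr[i] is the number of pages owned by request i.
device = "cuda" if torch.cuda.is_available() else "cpu"
kv_indptr = torch.tensor([0, 2, 3, 6], dtype=torch.int32, device=device)

# Mirror of `indptr_host = indptr.to("cpu")`: planning reads this on the host.
kv_indptr_host = kv_indptr.to("cpu")
pages_per_request = kv_indptr_host[1:] - kv_indptr_host[:-1]
print(pages_per_request.tolist())  # [2, 1, 3]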
@@ -1357,24 +1359,36 @@ def run(
         if rope_theta is None:
             rope_theta = 1e4
 
-        out = self._cached_module.run(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._plan_info,
-            q_nope,
-            q_pe,
-            paged_ckv_cache,
-            paged_kpe_cache,
-            self._paged_kv_indptr_buf,
-            self._paged_kv_indices_buf,
-            self._paged_kv_last_page_len_buf,
-            sm_scale,
-            window_left,
-            logits_soft_cap,
-            rope_scale,
-            rope_theta,
-            return_lse,
-        )
+        with self.device as device:
+            o = torch.empty_like(q_nope, device=device)
+            maybe_lse = (
+                torch.empty(
+                    (q_nope.size(0), q_nope.size(1)), dtype=torch.float32, device=device
+                )
+                if return_lse
+                else None
+            )
+            self._cached_module.run(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q_nope,
+                q_pe,
+                paged_ckv_cache,
+                paged_kpe_cache,
+                self._paged_kv_indptr_buf,
+                self._paged_kv_indices_buf,
+                self._paged_kv_last_page_len_buf,
+                o,
+                sm_scale,
+                window_left,
+                logits_soft_cap,
+                rope_scale,
+                rope_theta,
+                maybe_lse,
+                get_cuda_stream(device),
+            )
+            out = (o, maybe_lse) if return_lse else (o,)
         if v_scale is not None:
             out[0] *= v_scale
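
With the new JIT pipeline, run() follows a caller-allocates convention: the wrapper creates the output tensor (and the optional LSE buffer) up front, hands them to the compiled module together with the raw CUDA stream of the wrapper's device, and assembles the return value itself. The sketch below restates that pattern in plain PyTorch; run_kernel is a hypothetical stand-in for the compiled module, and torch.cuda.current_stream(...).cuda_stream is used in place of flashinfer's internal get_cuda_stream helper.

import torch

def caller_allocates_run(q_nope: torch.Tensor, return_lse: bool, run_kernel):
    # Sketch of the caller-allocates convention; `run_kernel` is a hypothetical
    # stand-in for the JIT-compiled module's run function.
    device = q_nope.device
    out = torch.empty_like(q_nope)  # output buffer owned by the caller
    lse = (
        torch.empty(q_nope.size(0), q_nope.size(1), dtype=torch.float32, device=device)
        if return_lse
        else None
    )
    # Integer handle of the current CUDA stream, passed through to the kernel.
    stream_handle = torch.cuda.current_stream(device).cuda_stream
    run_kernel(q_nope, out, lse, stream_handle)
    return (out, lse) if return_lse else (out,)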

python/flashinfer/jit/attention.py

Lines changed: 3 additions & 1 deletion
@@ -183,7 +183,9 @@ def get_batch_decode_mla_sources(
             "dtype_kv": dtype_map[dtype_kv],
             "dtype_o": dtype_map[dtype_o],
             "dtype_idx": dtype_map[dtype_idx],
-            "head_dim": head_dim,
+            "head_dim_ckv": head_dim,
+            "head_dim_kpe": head_dim
+            // 8,  # fixme: head_dim_ckv(kv_lora_rank) is 8 times the size of head_dim_kpe(qk_rope_head_dim) for all MLA model (DeepSeek-V2-Lite, DeepSeek-V2.5, MiniCPM3) at the time Oct.2024
             "use_sliding_window": "true" if use_sliding_window else "false",
             "use_logits_soft_cap": "true" if use_logits_soft_cap else "false",
         },
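
The template parameters now separate the compressed-KV head dimension (head_dim_ckv) from the rotary positional dimension (head_dim_kpe), with the latter derived as head_dim // 8. The fixme records why this works: the public MLA checkpoints listed there keep an 8:1 ratio between kv_lora_rank and qk_rope_head_dim. A quick sanity check with DeepSeek-V2's published config values (512 and 64, quoted from the model config rather than from this diff):

# Illustrative numbers from DeepSeek-V2's config, not part of this commit.
kv_lora_rank = 512      # -> head_dim_ckv, the compressed (latent) KV dimension
qk_rope_head_dim = 64   # -> head_dim_kpe, the rotary positional dimension

head_dim = kv_lora_rank
head_dim_ckv = head_dim
head_dim_kpe = head_dim // 8

assert (head_dim_ckv, head_dim_kpe) == (kv_lora_rank, qk_rope_head_dim)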

python/flashinfer/jit/batch_decode_mla_templ.py

Lines changed: 4 additions & 3 deletions
@@ -106,8 +106,8 @@
 
   if (maybe_lse) {
     const auto& lse = *maybe_lse;
-    TORCH_CHECK(lse.size(0) == batch_size, lse.size(0), q.size(0));
-    TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q.size(1));
+    TORCH_CHECK(lse.size(0) == batch_size, lse.size(0), q_nope.size(0));
+    TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q_nope.size(1));
   }
 
   TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
@@ -146,9 +146,10 @@
   }
   params.padded_batch_size = plan_info.padded_batch_size;
 
+  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
   cudaError_t status = BatchDecodeWithPagedKVCacheDispatchedMLA<
       {{ head_dim_ckv }}, {{ head_dim_kpe }}, AttentionVariant>(
-      params, tmp_v, tmp_s, /*stream=*/torch_current_stream);
+      params, tmp_v, tmp_s, /*stream=*/stream);
   TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
               cudaGetErrorString(status));
 }
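
The generated binding no longer captures torch_current_stream itself; it receives the stream as a plain integer (cuda_stream) and reinterpret_casts it back to cudaStream_t before dispatch. On the Python side, PyTorch already exposes such an integer handle; the small sketch below only illustrates the handle type crossing the boundary and is not flashinfer-specific.

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    # `.cuda_stream` is the underlying cudaStream_t exposed as a Python int;
    # the generated C++ recovers it via reinterpret_cast<cudaStream_t>(cuda_stream).
    handle = torch.cuda.current_stream(device).cuda_stream
    print(type(handle), hex(handle))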

tests/test_mla_decode_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -367,7 +367,7 @@ def run_proof_of_concept(
 
 dev_id = 1
 
-# torch.manual_seed(666)
+torch.manual_seed(666)
 torch.set_grad_enabled(False)
 
 mla_vanilla = DeepseekV2AttentionVanilla().cuda(device=dev_id)
@@ -436,7 +436,7 @@ def run_proof_of_concept(
     output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1)
 )
 print(f"wmape_use_torch_f16 = {wmape_use_torch_f16}")
-assert wmape_use_torch_f16 < 0.02
+assert wmape_use_torch_f16 < 0.03
 
 mse_use_torch_f16 = F.mse_loss(
     output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1)
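
The fp16 matrix-absorption tolerance is relaxed from 0.02 to 0.03 and the seed is now fixed so the comparison is reproducible. For reference, here is a minimal sketch of the weighted mean absolute percentage error the assertion bounds, assuming the test's wmape follows the standard definition with the first argument as the reference.

import torch

def wmape(reference: torch.Tensor, estimate: torch.Tensor) -> float:
    # Weighted MAPE: total absolute error normalized by total reference magnitude.
    return (reference - estimate).abs().sum().item() / reference.abs().sum().item()

# Mirroring the test's call order (argument roles assumed):
# wmape(output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1)) < 0.03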
