
Commit 6ec3bae

bugfix: fix the behavior of MLA kernel when kv-length is 0 (#868)
The scheduling algorithm in #863 does not consider that some requests may have a kv-cache length of 0; this PR fixes the issue.
1 parent 7cd000b commit 6ec3bae
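
For context, a request contributes zero KV length when it owns no entries in the paged KV cache. A minimal sketch of what such a batch could look like, assuming the usual indptr-style layout (hypothetical values, illustration only, not taken from the library):

    # Hypothetical batch of 3 requests; request 1 has an empty KV cache.
    kv_indptr = [0, 5, 5, 12]  # request i owns entries kv_indptr[i]:kv_indptr[i+1]
    kv_lens = [kv_indptr[i + 1] - kv_indptr[i] for i in range(3)]
    print(kv_lens)  # [5, 0, 7] -- the 0-length request is the case missed in #863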

2 files changed (+6 −3)

include/flashinfer/attention/scheduler.cuh

+3 −1

@@ -1134,7 +1134,8 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
             qo_indptr_h[i] * num_heads +
             std::min((qo_tile_idx + 1) * cluster_tile_q, packed_qo_len);
       }
-      while (remaining_len > 0) {
+      bool zero_kv_len = (remaining_len == 0);
+      while (remaining_len > 0 || zero_kv_len) {
         auto [cluster_idx, accum_cost] = cluster_cost_heap.pop();
         int actual_len = std::min(remaining_len, kv_len_limit);
         cluster_cost_heap.insert(
@@ -1154,6 +1155,7 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
         cluster_kv_end[cluster_idx].push_back(kv_start + actual_len);
         remaining_len -= actual_len;
         kv_start += actual_len;
+        if (zero_kv_len) break;
       }
       split_kv_count += int(split_kv);
     }
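
Before this change, a request with remaining_len == 0 never entered the while loop, so no KV tile was recorded for it at all; the new zero_kv_len flag forces exactly one iteration that records an empty [kv_start, kv_start) tile and then breaks. A simplified Python model of that control flow, assuming the cost heap balances accumulated KV length across clusters (heap and tile list are stand-ins for cluster_cost_heap and cluster_kv_start/cluster_kv_end; this is not the CUDA planner itself):

    import heapq

    def assign_kv_tiles(remaining_len, kv_len_limit, num_clusters=4):
        heap = [(0, c) for c in range(num_clusters)]   # (accumulated cost, cluster id)
        heapq.heapify(heap)
        tiles = []                                     # (cluster id, kv_start, kv_end)
        kv_start = 0
        zero_kv_len = (remaining_len == 0)             # the fix in this commit
        while remaining_len > 0 or zero_kv_len:
            accum_cost, cluster_idx = heapq.heappop(heap)
            actual_len = min(remaining_len, kv_len_limit)
            heapq.heappush(heap, (accum_cost + actual_len, cluster_idx))
            tiles.append((cluster_idx, kv_start, kv_start + actual_len))
            remaining_len -= actual_len
            kv_start += actual_len
            if zero_kv_len:                            # record one empty tile, then stop
                break
        return tiles

    print(assign_kv_tiles(0, 512))     # [(0, 0, 0)] -- one empty tile instead of none
    print(assign_kv_tiles(1000, 512))  # tiles covering [0, 512) and [512, 1000)

Without the flag, assign_kv_tiles(0, 512) would return an empty list, leaving the zero-length request unrepresented in the per-cluster work lists.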

tests/test_deepseek_mla.py

+3 −2

@@ -171,7 +171,7 @@ def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads):
 
 
 @pytest.mark.parametrize("batch_size", [1, 17, 37])
-@pytest.mark.parametrize("kv_len", [17, 33, 96, 97, 114, 514, 1024])
+@pytest.mark.parametrize("kv_len", [0, 17, 33, 96, 97, 114, 514, 1024])
 @pytest.mark.parametrize("qo_len", [1, 17, 37, 77])
 @pytest.mark.parametrize("num_heads", [4, 32, 128])
 @pytest.mark.parametrize("causal", [False, True])
@@ -243,7 +243,8 @@ def test_batch_mla_page_attention(
     o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
     lse_ref = lse_ref.flatten(0, 1)
     torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
-    torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
+    if kv_len != 0:
+        torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
 
     # test with pre-allocated output
     o_buffer = torch.empty_like(o)
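
The output check still runs for kv_len == 0; only the LSE comparison is skipped, since the log-sum-exp over an empty key set is degenerate (the reference reduces to log(0) = -inf) and the kernel's stored value is not required to match it. A tiny PyTorch illustration of that edge case, independent of the kernel:

    import torch

    # With zero keys there are no attention scores to reduce over, so the
    # reference log-sum-exp collapses to log(0) = -inf; the kernel may report
    # a different sentinel, hence the skipped lse assertion above.
    scores = torch.empty(0)          # per-query scores for an empty KV cache
    print(scores.exp().sum().log())  # tensor(-inf)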
