Skip to content

unittest: add unittests for MLA + cudagraph #890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/flashinfer/attention/scheduler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,7 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
total_kv_lens += effective_kv_len;
}
}
int kv_len_limit = ceil_div(ceil_div(total_kv_lens, num_clusters), 512L) * 512L;
int kv_len_limit = ceil_div(std::max(ceil_div(total_kv_lens, num_clusters), 1L), 512L) * 512L;

// step 1. load-balancing scheduling algorithm
MinHeap cluster_cost_heap(num_clusters);
Expand Down
57 changes: 52 additions & 5 deletions tests/test_deepseek_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads):
@pytest.mark.parametrize("num_heads", [16, 32, 64])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("page_size", [1])
@pytest.mark.parametrize("backend", ["fa2", "fa3"])
@pytest.mark.parametrize("backend", ["fa3"])
@pytest.mark.parametrize("dtype", [torch.half])
def test_batch_mla_varlen_page_attention(
batch_size,
Expand Down Expand Up @@ -311,15 +311,16 @@ def test_batch_mla_varlen_page_attention(
# torch.testing.assert_close(lse_i, lse_ref, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7, 157])
@pytest.mark.parametrize("kv_len", [0, 17, 33, 96, 97, 114, 514, 1024])
@pytest.mark.parametrize(
"qo_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
)
@pytest.mark.parametrize("num_heads", [16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("page_size", [1])
@pytest.mark.parametrize("page_size", [1, 16])
@pytest.mark.parametrize("backend", ["fa2", "fa3"])
@pytest.mark.parametrize("use_cuda_graph", [True, False])
@pytest.mark.parametrize("dtype", [torch.half])
def test_batch_mla_page_attention(
batch_size,
Expand All @@ -329,6 +330,7 @@ def test_batch_mla_page_attention(
causal,
page_size,
backend,
use_cuda_graph,
dtype,
):
if not mla_is_fa3_supported(torch.device("cuda")):
Expand Down Expand Up @@ -362,12 +364,51 @@ def test_batch_mla_page_attention(
sm_scale = 1.0 / ((128 + 64) ** 0.5) # use head dimension before matrix absorption
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
workspace_buffer, backend=backend
workspace_buffer,
backend=backend,
use_cuda_graph=True,
qo_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device="cuda"),
kv_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device="cuda"),
kv_indices=torch.empty(1048576, dtype=torch.int32, device="cuda"),
kv_len_arr=torch.empty(batch_size, dtype=torch.int32, device="cuda"),
)
q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * pages_num
kv_indices = torch.arange(0, batch_size * pages_num).to(0).int()
kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)

if use_cuda_graph:
kv_indptr_warmup = torch.zeros(batch_size + 1).to(0).int()
kv_indices_warmup = torch.arange(0, batch_size).to(0).int()
kv_lens_warmup = torch.full((batch_size,), 0, dtype=torch.int32).to(0)
wrapper.plan(
q_indptr,
kv_indptr_warmup,
kv_indices_warmup,
kv_lens_warmup,
num_heads,
head_dim_ckv,
head_dim_kpe,
page_size,
causal,
sm_scale,
q_nope.dtype,
ckv.dtype,
)

# warmup
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
torch.cuda.current_stream().wait_stream(s)

# capture
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)

wrapper.plan(
q_indptr,
kv_indptr,
Expand All @@ -382,7 +423,12 @@ def test_batch_mla_page_attention(
q_nope.dtype,
ckv.dtype,
)
o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
if use_cuda_graph:
o.fill_(0)
lse.fill_(0)
g.replay()
else:
o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)

k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads)

Expand All @@ -408,3 +454,4 @@ def test_batch_mla_page_attention(
test_batch_mla_varlen_page_attention(
155, 1024, 8, 128, 128, 16, False, 1, "fa3", torch.half
)
test_batch_mla_page_attention(1, 1024, 128, 128, False, 1, "fa2", True, torch.half)