diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh
index 7087101a5..44e8c0b40 100644
--- a/include/flashinfer/attention/scheduler.cuh
+++ b/include/flashinfer/attention/scheduler.cuh
@@ -1098,7 +1098,7 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
       total_kv_lens += effective_kv_len;
     }
   }
-  int kv_len_limit = ceil_div(ceil_div(total_kv_lens, num_clusters), 512L) * 512L;
+  int kv_len_limit = ceil_div(std::max(ceil_div(total_kv_lens, num_clusters), 1L), 512L) * 512L;
 
   // step 1. load-balancing scheduling algorithm
   MinHeap cluster_cost_heap(num_clusters);
diff --git a/tests/test_deepseek_mla.py b/tests/test_deepseek_mla.py
index 9448828c2..e93aa1798 100644
--- a/tests/test_deepseek_mla.py
+++ b/tests/test_deepseek_mla.py
@@ -189,7 +189,7 @@ def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads):
 @pytest.mark.parametrize("num_heads", [16, 32, 64])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("page_size", [1])
-@pytest.mark.parametrize("backend", ["fa2", "fa3"])
+@pytest.mark.parametrize("backend", ["fa3"])
 @pytest.mark.parametrize("dtype", [torch.half])
 def test_batch_mla_varlen_page_attention(
     batch_size,
@@ -311,15 +311,16 @@ def test_batch_mla_varlen_page_attention(
     # torch.testing.assert_close(lse_i, lse_ref, rtol=1e-3, atol=1e-3)
 
 
-@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7])
+@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7, 157])
 @pytest.mark.parametrize("kv_len", [0, 17, 33, 96, 97, 114, 514, 1024])
 @pytest.mark.parametrize(
     "qo_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
 )
 @pytest.mark.parametrize("num_heads", [16])
 @pytest.mark.parametrize("causal", [False, True])
-@pytest.mark.parametrize("page_size", [1])
+@pytest.mark.parametrize("page_size", [1, 16])
 @pytest.mark.parametrize("backend", ["fa2", "fa3"])
+@pytest.mark.parametrize("use_cuda_graph", [True, False])
 @pytest.mark.parametrize("dtype", [torch.half])
 def test_batch_mla_page_attention(
     batch_size,
@@ -329,6 +330,7 @@ def test_batch_mla_page_attention(
     causal,
     page_size,
     backend,
+    use_cuda_graph,
     dtype,
 ):
     if not mla_is_fa3_supported(torch.device("cuda")):
@@ -362,12 +364,51 @@ def test_batch_mla_page_attention(
     sm_scale = 1.0 / ((128 + 64) ** 0.5)  # use head dimension before matrix absorption
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
     wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
-        workspace_buffer, backend=backend
+        workspace_buffer,
+        backend=backend,
+        use_cuda_graph=True,
+        qo_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device="cuda"),
+        kv_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device="cuda"),
+        kv_indices=torch.empty(1048576, dtype=torch.int32, device="cuda"),
+        kv_len_arr=torch.empty(batch_size, dtype=torch.int32, device="cuda"),
     )
     q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
     kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * pages_num
     kv_indices = torch.arange(0, batch_size * pages_num).to(0).int()
     kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
+
+    if use_cuda_graph:
+        kv_indptr_warmup = torch.zeros(batch_size + 1).to(0).int()
+        kv_indices_warmup = torch.arange(0, batch_size).to(0).int()
+        kv_lens_warmup = torch.full((batch_size,), 0, dtype=torch.int32).to(0)
+        wrapper.plan(
+            q_indptr,
+            kv_indptr_warmup,
+            kv_indices_warmup,
+            kv_lens_warmup,
+            num_heads,
+            head_dim_ckv,
+            head_dim_kpe,
+            page_size,
+            causal,
+            sm_scale,
+            q_nope.dtype,
+            ckv.dtype,
+        )
+
+        # warmup
+        s = torch.cuda.Stream()
+        s.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(s):
+            for _ in range(3):
+                o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
+        torch.cuda.current_stream().wait_stream(s)
+
+        # capture
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
+
     wrapper.plan(
         q_indptr,
         kv_indptr,
@@ -382,7 +423,12 @@ def test_batch_mla_page_attention(
         q_nope.dtype,
         ckv.dtype,
     )
-    o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
+    if use_cuda_graph:
+        o.fill_(0)
+        lse.fill_(0)
+        g.replay()
+    else:
+        o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
 
     k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads)
 
@@ -408,3 +454,4 @@ def test_batch_mla_page_attention(
     test_batch_mla_varlen_page_attention(
         155, 1024, 8, 128, 128, 16, False, 1, "fa3", torch.half
     )
+    test_batch_mla_page_attention(1, 1024, 128, 128, False, 1, "fa2", True, torch.half)
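
The scheduler change above clamps the per-cluster KV budget so that a batch whose effective KV lengths are all zero (e.g. the kv_len=0 cases and the CUDA-graph warmup plan with all-zero kv_lens) no longer yields kv_len_limit == 0. A minimal Python sketch of that arithmetic, assuming a ceil_div that mirrors the C++ helper and an illustrative num_clusters value:

def ceil_div(a: int, b: int) -> int:
    # integer ceiling division, mirroring the C++ ceil_div helper
    return (a + b - 1) // b

def kv_len_limit_old(total_kv_lens: int, num_clusters: int) -> int:
    # pre-fix expression: a zero budget rounds to 0
    return ceil_div(ceil_div(total_kv_lens, num_clusters), 512) * 512

def kv_len_limit_new(total_kv_lens: int, num_clusters: int) -> int:
    # post-fix expression: clamps to at least one 512-element chunk
    return ceil_div(max(ceil_div(total_kv_lens, num_clusters), 1), 512) * 512

assert kv_len_limit_old(0, 132) == 0      # degenerate limit when every kv_len is 0
assert kv_len_limit_new(0, 132) == 512    # clamped to a single 512-sized chunk
assert kv_len_limit_old(8192, 132) == kv_len_limit_new(8192, 132) == 512  # unchanged otherwise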