
Commit 30b2838

bugfix: bugfix to #949 (#951)
The sm86/sm89 version of the MLA kernel was not tested after change #942; this PR fixes the issue. This PR also makes the following changes: 1. adds the MLA unittest to CI (on an a10g node); 2. shrinks the MLA unittest so that CI can finish in a reasonable time; 3. changes the skip condition `not is_sm90a_supported(torch.device("cuda"))` to `backend == "fa3" and not is_sm90a_supported(torch.device("cuda"))` so that non-FA3 cases still run on non-Hopper GPUs, as pointed out by @Atream.
1 parent 211dfc6 commit 30b2838
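As a reference for point 3, a minimal sketch of the new skip pattern; the test name and body below are illustrative only, the real tests live in tests/test_deepseek_mla.py further down:

    import pytest
    import torch

    from flashinfer.utils import is_sm90a_supported


    @pytest.mark.parametrize("backend", ["fa2", "fa3"])
    def test_example(backend):
        # Skip only when fa3 is requested on a device without sm90a support,
        # so the fa2 cases still run on non-Hopper GPUs (e.g. the a10g CI node).
        if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
            pytest.skip("FA3 is not supported on this device")
        ...  # actual test body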

File tree (6 files changed: +145 -39)

  Jenkinsfile                                 +10
  include/flashinfer/attention/mla.cuh        +1 -2
  include/flashinfer/attention/prefill.cuh    +21 -15
  scripts/task_jit_run_tests_part4.sh         +11
  tests/conftest.py                           +6 -4
  tests/test_deepseek_mla.py                  +96 -18


Jenkinsfile (+10)

@@ -106,6 +106,16 @@ stage('JIT Unittest') {
           sh(script: "${docker_run} ./scripts/task_jit_run_tests_part2.sh", label: 'JIT Unittest Part 2')
         }
       }
+    },
+    'GPU-G5-Test-4': {
+      node('GPU-G5-SPOT') {
+        ws(per_exec_ws('flashinfer-unittest')) {
+          init_git(true) // we need cutlass submodule
+          sh(script: "ls -alh", label: 'Show work directory')
+          sh(script: "./scripts/task_show_node_info.sh", label: 'Show node info')
+          sh(script: "${docker_run} ./scripts/task_jit_run_tests_part4.sh", label: 'JIT Unittest Part 4')
+        }
+      }
     }
   )
 }

include/flashinfer/attention/mla.cuh (+1 -2)

@@ -198,8 +198,7 @@ __device__ __forceinline__ void load_kv(
   if constexpr (KTraits::NUM_MMA_KV == 1) {
     if (warpgroup_idx == 0) {
       uint32_t q, r;
-      uint32_t packed_block_iter =
-          packed_block_iter_base + lane_idx / 8 + lane_idx / 8 + warp_idx_in_wg * 4;
+      uint32_t packed_block_iter = packed_block_iter_base + lane_idx / 8 + warp_idx_in_wg * 4;
       block_size.divmod(packed_block_iter, q, r);

       DTypeKV* ckv_ptr = ckv +
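The change here removes an accidentally duplicated `lane_idx / 8` term from the packed block iterator computation in `load_kv`.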

include/flashinfer/attention/prefill.cuh (+21 -15)

@@ -1504,19 +1504,21 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D
   FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(
       &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
   // we expect each sm execute two threadblocks
-  // TODO(Zihao): fix the following computation
-  const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ) * 16) ? 2 : 1;
+  const int num_ctas_per_sm =
+      max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) +
+                              (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV))
+          ? 2
+          : 1;
   const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;

   const uint32_t max_num_mma_kv_reg =
       (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama &&
        !USE_FP16_QK_REDUCTION)
           ? 2
           : (8 / NUM_MMA_Q);
-  // TODO(Zihao): fix the following computation
   const uint32_t max_num_mma_kv_smem =
-      (max_smem_per_threadblock / (16 * HEAD_DIM_QK * sizeof(DTypeQ)) - NUM_MMA_Q * NUM_WARPS_Q) /
-      (2 * NUM_WARPS_KV);
+      (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) /
+      ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV));

   // control NUM_MMA_KV for maximum warp occupancy
   DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, {

@@ -2223,19 +2225,21 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para
   FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm,
                                               cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
   // we expect each sm execute two threadblocks
-  // TODO(Zihao): fix the following computation
-  const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ) * 16) ? 2 : 1;
+  const int num_ctas_per_sm =
+      max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) +
+                              (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV))
+          ? 2
+          : 1;
   const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;

   const uint32_t max_num_mma_kv_reg =
       (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama &&
        !USE_FP16_QK_REDUCTION)
           ? 2
           : (8 / NUM_MMA_Q);
-  // TODO(Zihao): fix the following computation
   const uint32_t max_num_mma_kv_smem =
-      (max_smem_per_threadblock / (16 * HEAD_DIM_QK * sizeof(DTypeQ)) - NUM_MMA_Q * NUM_WARPS_Q) /
-      (2 * NUM_WARPS_KV);
+      (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) /
+      ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV));

   DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, {
     using KTraits =

@@ -2324,19 +2328,21 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Param
   FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm,
                                               cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
   // we expect each sm execute two threadblocks
-  // TODO(Zihao): fix the following computation
-  const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ) * 16) ? 2 : 1;
+  const int num_ctas_per_sm =
+      max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) +
+                              (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV))
+          ? 2
+          : 1;
   const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;

   const uint32_t max_num_mma_kv_reg =
       (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama &&
        !USE_FP16_QK_REDUCTION)
           ? 2
           : (8 / NUM_MMA_Q);
-  // TODO(Zihao): fix the following computation
   const uint32_t max_num_mma_kv_smem =
-      (max_smem_per_threadblock / (16 * HEAD_DIM_QK * sizeof(DTypeQ)) - NUM_MMA_Q * NUM_WARPS_Q) /
-      (2 * NUM_WARPS_KV);
+      (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) /
+      ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV));

   DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, {
     using KTraits =
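As a rough worked illustration of the corrected shared-memory budgeting above, here is a small Python restatement of the new formulas; every numeric value below is an assumption picked for the example, not something read from the dispatcher:

    # Hypothetical values for illustration only; the real values come from the
    # kernel traits selected at dispatch time.
    CTA_TILE_Q = 128
    HEAD_DIM_QK = 128
    HEAD_DIM_VO = 128
    NUM_WARPS_KV = 4
    sizeof_DTypeQ = 2   # fp16
    sizeof_DTypeKV = 2  # fp16
    max_smem_per_sm = 100 * 1024  # assumed per-SM shared-memory budget

    q_smem = CTA_TILE_Q * HEAD_DIM_QK * sizeof_DTypeQ
    kv_smem_per_mma_kv = (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof_DTypeKV

    # run two CTAs per SM only if two Q tiles plus one KV tile each fit
    num_ctas_per_sm = 2 if max_smem_per_sm >= 2 * (q_smem + kv_smem_per_mma_kv) else 1
    max_smem_per_threadblock = max_smem_per_sm // num_ctas_per_sm

    # shared memory left after the Q tile bounds how many KV tiles (NUM_MMA_KV) fit
    max_num_mma_kv_smem = (max_smem_per_threadblock - q_smem) // kv_smem_per_mma_kv

    print(num_ctas_per_sm, max_num_mma_kv_smem)  # -> 1, 2 with these assumed numbers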

scripts/task_jit_run_tests_part4.sh (+11)

@@ -0,0 +1,11 @@
+#!/bin/bash
+
+set -eo pipefail
+set -x
+: ${MAX_JOBS:=$(nproc)}
+: ${CUDA_VISIBLE_DEVICES:=0}
+
+pip install -e . -v
+
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True # avoid memory fragmentation
+pytest -s tests/test_deepseek_mla.py
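This is the script invoked by the new 'GPU-G5-Test-4' stage added to the Jenkinsfile above (`${docker_run} ./scripts/task_jit_run_tests_part4.sh`).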

tests/conftest.py (+6 -4)

@@ -121,15 +121,17 @@ def pytest_configure(config):
         _monkeypatch_add_torch_compile(fn)


+def is_cuda_oom_error_str(e: str) -> bool:
+    return "CUDA" in e and "out of memory" in e
+
+
 @pytest.hookimpl(tryfirst=True)
 def pytest_runtest_call(item):
     # skip OOM error
     try:
         item.runtest()
-    except (torch.OutOfMemoryError, RuntimeError) as e:
-        if isinstance(e, torch.OutOfMemoryError) or "CUDA error: out of memory" in str(
-            e
-        ):
+    except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
+        if isinstance(e, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(str(e)):
             pytest.skip("Skipping due to OOM")
         else:
             raise
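Compared with the old exact substring check ("CUDA error: out of memory"), the new helper matches any message containing both "CUDA" and "out of memory". A minimal self-contained sketch (the error strings below are illustrative):

    # same predicate as in the conftest above
    def is_cuda_oom_error_str(e: str) -> bool:
        return "CUDA" in e and "out of memory" in e

    assert is_cuda_oom_error_str("CUDA error: out of memory")
    assert is_cuda_oom_error_str("CUDA out of memory. Tried to allocate 2.00 GiB")
    assert not is_cuda_oom_error_str("CUDA error: an illegal memory access was encountered")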

tests/test_deepseek_mla.py (+96 -18)

@@ -20,9 +20,91 @@
 import torch

 import flashinfer
+from flashinfer.jit.attention import (
+    gen_batch_mla_module,
+    gen_batch_prefill_module,
+    gen_single_prefill_module,
+)
 from flashinfer.utils import is_sm90a_supported


+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    try:
+        modules = []
+        for backend in ["fa2", "fa3"]:
+            if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
+                continue
+
+            modules.append(
+                (
+                    gen_single_prefill_module,
+                    [
+                        backend,
+                        torch.float16,
+                        torch.float16,
+                        torch.float16,
+                        192,
+                        128,
+                        0,
+                        False,
+                        False,
+                        False,
+                    ],
+                )
+            )
+
+        for backend in ["fa2", "fa3"]:
+            if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
+                continue
+
+            modules.append(
+                (
+                    gen_batch_prefill_module,
+                    [
+                        backend,
+                        torch.float16,
+                        torch.float16,
+                        torch.float16,
+                        torch.int32,
+                        192,
+                        128,
+                        0,
+                        False,
+                        False,
+                        False,
+                    ],
+                )
+            )
+
+        for backend in ["fa2", "fa3"]:
+            if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
+                continue
+
+            modules.append(
+                (
+                    gen_batch_mla_module,
+                    [
+                        backend,
+                        torch.float16,
+                        torch.float16,
+                        torch.float16,
+                        torch.int32,
+                        512,
+                        64,
+                        False,
+                    ],
+                )
+            )
+
+        flashinfer.jit.parallel_load_modules(modules)
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 def attention_ref(
     batch_size,
     q: torch.Tensor,

@@ -83,7 +165,7 @@ def test_single_prefill_with_kv_cache(
     backend,
     dtype,
 ):
-    if not is_sm90a_supported(torch.device("cuda")):
+    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
         pytest.skip("FA3 is not supported on this device")
     torch.manual_seed(42)
     head_dim_qk = 192

@@ -117,7 +199,7 @@ def test_batch_prefill_with_ragged_kv_cache(
     backend,
     dtype,
 ):
-    if not is_sm90a_supported(torch.device("cuda")):
+    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
         pytest.skip("FA3 is not supported on this device")
     torch.manual_seed(42)
     kv_layout = "NHD"

@@ -188,17 +270,15 @@ def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads):
     return k, v


-@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7])
-@pytest.mark.parametrize("kv_len_0", [0, 1, 2, 3, 4, 11])
-@pytest.mark.parametrize("kv_len_1", [17, 19, 33, 79, 114])
+@pytest.mark.parametrize("batch_size", [1, 3, 5, 7])
+@pytest.mark.parametrize("kv_len_0", [0, 1, 3, 11])
+@pytest.mark.parametrize("kv_len_1", [17, 33, 79, 114])
 @pytest.mark.parametrize("kv_len_2", [514, 2743, 8736])
-@pytest.mark.parametrize(
-    "qo_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
-)
-@pytest.mark.parametrize("num_heads", [16, 32, 64])
+@pytest.mark.parametrize("qo_len", [1, 3, 5, 7, 9, 11, 13, 15, 17])
+@pytest.mark.parametrize("num_heads", [16, 64])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("page_size", [1])
-@pytest.mark.parametrize("backend", ["fa3"])
+@pytest.mark.parametrize("backend", ["fa2", "fa3"])
 @pytest.mark.parametrize("dtype", [torch.half])
 def test_batch_mla_varlen_page_attention(
     batch_size,

@@ -212,7 +292,7 @@ def test_batch_mla_varlen_page_attention(
     backend,
     dtype,
 ):
-    if not is_sm90a_supported(torch.device("cuda")):
+    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
         pytest.skip("FA3 is not supported on this device")
     if causal and qo_len > min(kv_len_0, kv_len_1, kv_len_2):
         pytest.skip("qo_len > kv_len not supported for causal attention")

@@ -336,7 +416,7 @@ def test_batch_mla_varlen_page_attention(
 def test_batch_mla_oob_kv_nan(
     batch_size, kv_len, qo_len, num_heads, causal, page_size, backend, dtype
 ):
-    if not is_sm90a_supported(torch.device("cuda")):
+    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
         pytest.skip("FA3 is not supported on this device")
     if causal and qo_len > kv_len:
         pytest.skip("qo_len > kv_len not supported for causal attention")

@@ -405,16 +485,14 @@ def test_batch_mla_oob_kv_nan(
     torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)


-@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7, 157])
+@pytest.mark.parametrize("batch_size", [1, 3, 5, 7, 157])
 @pytest.mark.parametrize("kv_len", [0, 17, 33, 96, 97, 114, 514, 1024])
-@pytest.mark.parametrize(
-    "qo_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
-)
+@pytest.mark.parametrize("qo_len", [1, 3, 5, 7, 9, 11, 13, 15, 17])
 @pytest.mark.parametrize("num_heads", [16])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("page_size", [1, 16])
 @pytest.mark.parametrize("backend", ["fa2", "fa3"])
-@pytest.mark.parametrize("use_cuda_graph", [True, False])
+@pytest.mark.parametrize("use_cuda_graph", [False])
 @pytest.mark.parametrize("dtype", [torch.half])
 def test_batch_mla_page_attention(
     batch_size,

@@ -427,7 +505,7 @@ def test_batch_mla_page_attention(
     use_cuda_graph,
     dtype,
 ):
-    if not is_sm90a_supported(torch.device("cuda")):
+    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
         pytest.skip("FA3 is not supported on this device")
     if causal and qo_len > kv_len:
         pytest.skip("qo_len > kv_len not supported for causal attention")
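The new module-scoped `warmup_jit` fixture above pre-builds the required prefill and MLA JIT modules via `flashinfer.jit.parallel_load_modules`, so the shrunken parameter sweeps do not each pay compilation cost during the CI run.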
