Commit 27906fd

fix - fix bug when not relevant seq has nan data (#942)
1 parent 061db55 commit 27906fd

2 files changed: +106 −22 lines
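Summary of the fix: load_kv used to guard its KV loads with a page-granular bound, kv_bound = kv_indptr + ceil_div(kv_len, block_size), tested as q < kv_bound where q is the page slot. Every position inside the last, partially filled page passed that test, so lanes past kv_len loaded whatever the page held there, including NaN, and SharedMemFillMode::kFillZero never kicked in. The patch switches to an element-granular bound, packed_kv_bound = kv_indptr * block_size + kv_len, compared against the packed position itself, so tail lanes fail the predicate and are zero-filled instead. A minimal sketch of the two predicates with made-up numbers (plain Python, not the kernel code):

import math

# Hypothetical toy numbers, for illustration only: one request whose pages
# start at slot kv_indptr of kv_indices, with a partially filled last page.
block_size = 16                 # page size
kv_indptr = 2                   # first kv_indices slot of this request
kv_len = 17                     # valid tokens -> the last page holds only 1

# Old guard: page-granular upper bound (ceil_div), so every slot inside the
# last page passes the check, including the 15 positions past kv_len.
kv_bound = kv_indptr + math.ceil(kv_len / block_size)      # 2 + 2 = 4

# New guard: element-granular bound in packed-position space.
packed_kv_bound = kv_indptr * block_size + kv_len          # 32 + 17 = 49

packed_block_iter = kv_indptr * block_size + 20            # position 20 >= kv_len
q, r = divmod(packed_block_iter, block_size)               # page slot q = 3, offset r = 4
print(q < kv_bound)                         # True:  old predicate loads the NaN tail
print(packed_block_iter < packed_kv_bound)  # False: new predicate zero-fills instead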

include/flashinfer/attention/mla.cuh (+25 −21)
@@ -181,8 +181,8 @@ __device__ __forceinline__ void load_kv(
     typename KTraits::SharedStorage* smem_storage, typename KTraits::DTypeKV* ckv,
     typename KTraits::DTypeKV* kpe, typename KTraits::IdType* indices, const uint32_t ckv_stride_n,
     const uint32_t ckv_stride_page, const uint32_t kpe_stride_n, const uint32_t kpe_stride_page,
-    const uint32_t kv_bound, const uint32_t packed_block_iter_base, const uint_fastdiv& block_size,
-    const uint32_t stage_idx) {
+    const uint32_t packed_kv_bound, const uint32_t packed_block_iter_base,
+    const uint_fastdiv& block_size, const uint32_t stage_idx) {
   using DTypeKV = typename KTraits::DTypeKV;
   constexpr uint32_t UPCAST_STRIDE_CKV = KTraits::UPCAST_STRIDE_CKV;
   constexpr uint32_t UPCAST_STRIDE_KPE = KTraits::UPCAST_STRIDE_KPE;
@@ -198,20 +198,23 @@ __device__ __forceinline__ void load_kv(
   if constexpr (KTraits::NUM_MMA_KV == 1) {
     if (warpgroup_idx == 0) {
       uint32_t q, r;
+      uint32_t packed_block_iter =
+          packed_block_iter_base + lane_idx / 8 + warp_idx_in_wg * 4;
+      block_size.divmod(packed_block_iter, q, r);
 
-      block_size.divmod(packed_block_iter_base + lane_idx / 8 + warp_idx_in_wg * 4, q, r);
-
-      DTypeKV* ckv_ptr = ckv + (q < kv_bound ? indices[q] : 0) * ckv_stride_page +
+      DTypeKV* ckv_ptr = ckv +
+                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
                          r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
-      DTypeKV* kpe_ptr = kpe + (q < kv_bound ? indices[q] : 0) * kpe_stride_page +
+      DTypeKV* kpe_ptr = kpe +
+                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
                          r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
 
 #pragma unroll
       for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 4; ++mma_d) {
         uint32_t ckv_smem_offset_w = ckv_smem.template get_permuted_offset<UPCAST_STRIDE_CKV>(
             warp_idx_in_wg * 4 + lane_idx / 8, 8 * mma_d + lane_idx % 8);
         ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(ckv_smem_offset_w, ckv_ptr,
-                                                               q < kv_bound);
+                                                               packed_block_iter < packed_kv_bound);
         ckv_ptr += 8 * upcast_size<DTypeKV>();
       }
 
@@ -220,22 +223,23 @@ __device__ __forceinline__ void load_kv(
         uint32_t kpe_smem_offset_w = kpe_smem.template get_permuted_offset<UPCAST_STRIDE_KPE>(
             warp_idx_in_wg * 4 + lane_idx / 8, 8 * mma_d + lane_idx % 8);
         kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(kpe_smem_offset_w, kpe_ptr,
-                                                               q < kv_bound);
+                                                               packed_block_iter < packed_kv_bound);
         kpe_ptr += 8 * upcast_size<DTypeKV>();
       }
     }
   } else {
 #pragma unroll
     for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV / 2; ++mma_kv) {
       uint32_t q, r;
+      uint32_t packed_block_iter = packed_block_iter_base + lane_idx / 8 +
+                                   (warpgroup_idx + mma_kv * 2) * 16 + warp_idx_in_wg * 4;
+      block_size.divmod(packed_block_iter, q, r);
 
-      block_size.divmod(packed_block_iter_base + lane_idx / 8 + (warpgroup_idx + mma_kv * 2) * 16 +
-                            warp_idx_in_wg * 4,
-                        q, r);
-
-      DTypeKV* ckv_ptr = ckv + (q < kv_bound ? indices[q] : 0) * ckv_stride_page +
+      DTypeKV* ckv_ptr = ckv +
+                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
                          r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
-      DTypeKV* kpe_ptr = kpe + (q < kv_bound ? indices[q] : 0) * kpe_stride_page +
+      DTypeKV* kpe_ptr = kpe +
+                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
                          r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
 
 #pragma unroll
@@ -244,7 +248,7 @@ __device__ __forceinline__ void load_kv(
             32 * mma_kv + warpgroup_idx * 16 + warp_idx_in_wg * 4 + lane_idx / 8,
             8 * mma_d + lane_idx % 8);
         ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(ckv_smem_offset_w, ckv_ptr,
-                                                               q < kv_bound);
+                                                               packed_block_iter < packed_kv_bound);
         ckv_ptr += 8 * upcast_size<DTypeKV>();
       }
 
@@ -254,7 +258,7 @@ __device__ __forceinline__ void load_kv(
             32 * mma_kv + warpgroup_idx * 16 + warp_idx_in_wg * 4 + lane_idx / 8,
             8 * mma_d + lane_idx % 8);
         kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(kpe_smem_offset_w, kpe_ptr,
-                                                               q < kv_bound);
+                                                               packed_block_iter < packed_kv_bound);
         kpe_ptr += 8 * upcast_size<DTypeKV>();
       }
     }
@@ -863,17 +867,17 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKernel(
     uint32_t block_iter_base = kv_indptr * block_size + kv_start;
     // last kv tile
     __syncthreads();
-    uint32_t kv_bound = kv_indptr + (kv_len + block_size - 1) / block_size;  // ceil_div
+    uint32_t packed_kv_bound = kv_indptr * block_size + kv_len;
     load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                     kpe_stride_n, kpe_stride_page, kv_bound,
+                     kpe_stride_n, kpe_stride_page, packed_kv_bound,
                      block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
                      kv_tile_idx % NUM_STAGES);
     cp_async::commit_group();
 #pragma unroll
     for (int stage_idx = 1; stage_idx < NUM_STAGES; ++stage_idx) {
       if (kv_tile_idx - stage_idx >= 0) {
         load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                         kpe_stride_n, kpe_stride_page, kv_bound,
+                         kpe_stride_n, kpe_stride_page, packed_kv_bound,
                          block_iter_base + (kv_tile_idx - stage_idx) * CTA_TILE_KV, block_size,
                          (kv_tile_idx - stage_idx) % NUM_STAGES);
         cp_async::commit_group();
@@ -903,7 +907,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKernel(
     if (kv_tile_idx - NUM_STAGES >= 0) {
       __syncthreads();
       load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                       kpe_stride_n, kpe_stride_page, kv_bound,
+                       kpe_stride_n, kpe_stride_page, packed_kv_bound,
                        block_iter_base + (kv_tile_idx - NUM_STAGES) * CTA_TILE_KV, block_size,
                        (kv_tile_idx - NUM_STAGES) % NUM_STAGES);
       cp_async::commit_group();
@@ -927,7 +931,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKernel(
 
     __syncthreads();
     load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                     kpe_stride_n, kpe_stride_page, kv_bound,
+                     kpe_stride_n, kpe_stride_page, packed_kv_bound,
                      block_iter_base + (kv_tile_idx - NUM_STAGES) * CTA_TILE_KV, block_size,
                      (kv_tile_idx - NUM_STAGES) % NUM_STAGES);
     cp_async::commit_group();
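Why stale tail values matter even though those positions are masked out downstream: attention masking adds -inf to out-of-range logits before the softmax, and nan + (-inf) is still nan, so a single NaN logit poisons the whole row. A standalone torch illustration of that failure mode (not the kernel's actual code path):

import torch

# Last logit was read from the NaN tail of a partially filled page.
logits = torch.tensor([1.0, 2.0, float("nan")])
mask = torch.tensor([0.0, 0.0, float("-inf")])     # position 2 is past kv_len
print(torch.softmax(logits + mask, dim=0))         # tensor([nan, nan, nan])

# If the load is predicated off and zero-filled instead (kFillZero), the
# masked softmax stays finite:
logits_zero_filled = torch.tensor([1.0, 2.0, 0.0])
print(torch.softmax(logits_zero_filled + mask, dim=0))  # tensor([0.2689, 0.7311, 0.0000])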

tests/test_deepseek_mla.py (+81 −1)
@@ -325,6 +325,86 @@ def test_batch_mla_varlen_page_attention(
     # torch.testing.assert_close(lse_i, lse_ref, rtol=1e-3, atol=1e-3)
 
 
+@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7, 157])
+@pytest.mark.parametrize("kv_len", [17, 33, 75, 197])
+@pytest.mark.parametrize("qo_len", [3, 7, 17])
+@pytest.mark.parametrize("num_heads", [16])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("page_size", [16, 32])
+@pytest.mark.parametrize("backend", ["fa2", "fa3"])
+@pytest.mark.parametrize("dtype", [torch.half])
+def test_batch_mla_oob_kv_nan(
+    batch_size, kv_len, qo_len, num_heads, causal, page_size, backend, dtype
+):
+    if not is_sm90a_supported(torch.device("cuda")):
+        pytest.skip("FA3 is not supported on this device")
+    if causal and qo_len > kv_len:
+        pytest.skip("qo_len > kv_len not supported for causal attention")
+    torch.manual_seed(42)
+    head_dim_ckv = 512
+    head_dim_kpe = 64
+    q_nope = torch.randn(
+        batch_size * qo_len, num_heads, head_dim_ckv, dtype=dtype, device="cuda"
+    )
+    q_pe = torch.randn(
+        batch_size * qo_len, num_heads, head_dim_kpe, dtype=dtype, device="cuda"
+    )
+    pages_num = math.ceil(kv_len / page_size)
+    ckv = torch.randn(
+        batch_size * pages_num, page_size, head_dim_ckv, dtype=dtype, device="cuda"
+    )
+    kpe = torch.randn(
+        batch_size * pages_num, page_size, head_dim_kpe, dtype=dtype, device="cuda"
+    )
+
+    # Fill oob positions with nan
+    for i in range(batch_size):
+        last_page_len = kv_len - (pages_num - 1) * page_size
+        ckv[(i + 1) * pages_num - 1, last_page_len:, :] = float("nan")
+        kpe[(i + 1) * pages_num - 1, last_page_len:, :] = float("nan")
+
+    sm_scale = 1.0 / ((128 + 64) ** 0.5)
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
+    wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
+        workspace_buffer, backend=backend
+    )
+    q_indptr = (
+        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
+    )
+    kv_indptr = (
+        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * pages_num
+    )
+    kv_indices = torch.arange(
+        0, batch_size * pages_num, device="cuda", dtype=torch.int32
+    )
+    kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
+
+    wrapper.plan(
+        q_indptr,
+        kv_indptr,
+        kv_indices,
+        kv_lens,
+        num_heads,
+        head_dim_ckv,
+        head_dim_kpe,
+        page_size,
+        causal,
+        sm_scale,
+        q_nope.dtype,
+        ckv.dtype,
+    )
+    o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
+
+    k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads)
+
+    q = torch.cat([q_nope, q_pe], dim=-1)
+    o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
+    lse_ref = lse_ref.flatten(0, 1)
+    torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
+    if kv_len != 0:
+        torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
+
+
 @pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7, 157])
 @pytest.mark.parametrize("kv_len", [0, 17, 33, 96, 97, 114, 514, 1024])
 @pytest.mark.parametrize(
@@ -471,7 +551,7 @@ def test_batch_mla_page_attention(
 
 if __name__ == "__main__":
     test_batch_mla_varlen_page_attention(
-        155, 64, 8, 8, 128, 16, False, 1, "fa3", torch.half
+        1, 65, 65, 65, 1, 128, True, 64, "fa2", torch.half
     )
     # test_batch_mla_varlen_page_attention(
     #     155, 1024, 8, 128, 128, 16, False, 1, "fa3", torch.half
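The new test_batch_mla_oob_kv_nan deliberately fills every slot past kv_len in each request's last page with NaN and checks the kernel output against a clean reference. One way to reproduce a single case directly, mirroring the __main__ pattern above (the import path and argument choice are assumptions, not part of the commit; run from the tests/ directory on an SM90a GPU):

import torch
from test_deepseek_mla import test_batch_mla_oob_kv_nan

# kv_len=17 with page_size=16 leaves 15 NaN-filled slots in each request's
# last page, exactly the tail the old page-granular guard used to load.
test_batch_mla_oob_kv_nan(
    batch_size=4, kv_len=17, qo_len=7, num_heads=16,
    causal=True, page_size=16, backend="fa3", dtype=torch.half,
)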
