fix: fix bug when non-relevant KV sequence has NaN data #942

Merged · 2 commits · Mar 15, 2025
46 changes: 25 additions & 21 deletions include/flashinfer/attention/mla.cuh
@@ -181,8 +181,8 @@ __device__ __forceinline__ void load_kv(
typename KTraits::SharedStorage* smem_storage, typename KTraits::DTypeKV* ckv,
typename KTraits::DTypeKV* kpe, typename KTraits::IdType* indices, const uint32_t ckv_stride_n,
const uint32_t ckv_stride_page, const uint32_t kpe_stride_n, const uint32_t kpe_stride_page,
- const uint32_t kv_bound, const uint32_t packed_block_iter_base, const uint_fastdiv& block_size,
- const uint32_t stage_idx) {
+ const uint32_t packed_kv_bound, const uint32_t packed_block_iter_base,
+ const uint_fastdiv& block_size, const uint32_t stage_idx) {
using DTypeKV = typename KTraits::DTypeKV;
constexpr uint32_t UPCAST_STRIDE_CKV = KTraits::UPCAST_STRIDE_CKV;
constexpr uint32_t UPCAST_STRIDE_KPE = KTraits::UPCAST_STRIDE_KPE;
@@ -198,20 +198,23 @@ __device__ __forceinline__ void load_kv(
if constexpr (KTraits::NUM_MMA_KV == 1) {
if (warpgroup_idx == 0) {
uint32_t q, r;
- block_size.divmod(packed_block_iter_base + lane_idx / 8 + warp_idx_in_wg * 4, q, r);
+ uint32_t packed_block_iter =
+     packed_block_iter_base + lane_idx / 8 + warp_idx_in_wg * 4;
+ block_size.divmod(packed_block_iter, q, r);

- DTypeKV* ckv_ptr = ckv + (q < kv_bound ? indices[q] : 0) * ckv_stride_page +
+ DTypeKV* ckv_ptr = ckv +
+ (packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
- DTypeKV* kpe_ptr = kpe + (q < kv_bound ? indices[q] : 0) * kpe_stride_page +
+ DTypeKV* kpe_ptr = kpe +
+ (packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();

#pragma unroll
for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 4; ++mma_d) {
uint32_t ckv_smem_offset_w = ckv_smem.template get_permuted_offset<UPCAST_STRIDE_CKV>(
warp_idx_in_wg * 4 + lane_idx / 8, 8 * mma_d + lane_idx % 8);
ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(ckv_smem_offset_w, ckv_ptr,
- q < kv_bound);
+ packed_block_iter < packed_kv_bound);
ckv_ptr += 8 * upcast_size<DTypeKV>();
}

@@ -220,22 +223,23 @@ __device__ __forceinline__ void load_kv(
uint32_t kpe_smem_offset_w = kpe_smem.template get_permuted_offset<UPCAST_STRIDE_KPE>(
warp_idx_in_wg * 4 + lane_idx / 8, 8 * mma_d + lane_idx % 8);
kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(kpe_smem_offset_w, kpe_ptr,
- q < kv_bound);
+ packed_block_iter < packed_kv_bound);
kpe_ptr += 8 * upcast_size<DTypeKV>();
}
}
} else {
#pragma unroll
for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV / 2; ++mma_kv) {
uint32_t q, r;
- block_size.divmod(packed_block_iter_base + lane_idx / 8 + (warpgroup_idx + mma_kv * 2) * 16 +
-                       warp_idx_in_wg * 4,
-                   q, r);
+ uint32_t packed_block_iter = packed_block_iter_base + lane_idx / 8 +
+                              (warpgroup_idx + mma_kv * 2) * 16 + warp_idx_in_wg * 4;
+ block_size.divmod(packed_block_iter, q, r);

- DTypeKV* ckv_ptr = ckv + (q < kv_bound ? indices[q] : 0) * ckv_stride_page +
+ DTypeKV* ckv_ptr = ckv +
+ (packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
- DTypeKV* kpe_ptr = kpe + (q < kv_bound ? indices[q] : 0) * kpe_stride_page +
+ DTypeKV* kpe_ptr = kpe +
+ (packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();

#pragma unroll
@@ -244,7 +248,7 @@ __device__ __forceinline__ void load_kv(
32 * mma_kv + warpgroup_idx * 16 + warp_idx_in_wg * 4 + lane_idx / 8,
8 * mma_d + lane_idx % 8);
ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(ckv_smem_offset_w, ckv_ptr,
- q < kv_bound);
+ packed_block_iter < packed_kv_bound);
ckv_ptr += 8 * upcast_size<DTypeKV>();
}

@@ -254,7 +258,7 @@ __device__ __forceinline__ void load_kv(
32 * mma_kv + warpgroup_idx * 16 + warp_idx_in_wg * 4 + lane_idx / 8,
8 * mma_d + lane_idx % 8);
kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(kpe_smem_offset_w, kpe_ptr,
- q < kv_bound);
+ packed_block_iter < packed_kv_bound);
kpe_ptr += 8 * upcast_size<DTypeKV>();
}
}
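
For readers following the predicate change above, here is a minimal host-side sketch in plain Python (the helper name `simulate_tile_loads` is hypothetical, it is not the CUDA code itself, and it assumes a single request with `kv_indptr = 0`). The old guard compared only the page index `q` against a ceil-divided page count, so in-page offsets past `kv_len` inside the last, partially filled page still passed the check and read whatever garbage (possibly NaN) sat there; the new guard compares the packed position against `packed_kv_bound`, so those tail slots take the zero-fill path instead.

```python
# Hypothetical illustration of the old vs. new guard in load_kv (not library code).
# Packed positions enumerate KV elements as page_index * block_size + in_page_offset.
def simulate_tile_loads(kv_len: int, block_size: int, tile_positions: range):
    kv_bound_old = (kv_len + block_size - 1) // block_size  # old bound: number of pages (ceil_div)
    packed_kv_bound_new = kv_len                             # new bound: number of elements (kv_indptr == 0)
    rows = []
    for packed_pos in tile_positions:
        q, r = divmod(packed_pos, block_size)                # page index and in-page offset
        old_loads = q < kv_bound_old                         # old predicate: page-level only
        new_loads = packed_pos < packed_kv_bound_new         # new predicate: element-level
        rows.append((packed_pos, q, r, old_loads, new_loads))
    return rows

# Example: kv_len = 17, block_size = 16 -> two pages, only one valid row in the last page.
for pos, q, r, old_loads, new_loads in simulate_tile_loads(17, 16, range(32)):
    if old_loads != new_loads:
        print(f"pos {pos} (page {q}, offset {r}): old guard loads it, new guard zero-fills it")
```
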
@@ -863,17 +867,17 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
uint32_t block_iter_base = kv_indptr * block_size + kv_start;
// last kv tile
__syncthreads();
- uint32_t kv_bound = kv_indptr + (kv_len + block_size - 1) / block_size;  // ceil_div
+ uint32_t packed_kv_bound = kv_indptr * block_size + kv_len;
load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
- kpe_stride_n, kpe_stride_page, kv_bound,
+ kpe_stride_n, kpe_stride_page, packed_kv_bound,
block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
kv_tile_idx % NUM_STAGES);
cp_async::commit_group();
#pragma unroll
for (int stage_idx = 1; stage_idx < NUM_STAGES; ++stage_idx) {
if (kv_tile_idx - stage_idx >= 0) {
load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
- kpe_stride_n, kpe_stride_page, kv_bound,
+ kpe_stride_n, kpe_stride_page, packed_kv_bound,
block_iter_base + (kv_tile_idx - stage_idx) * CTA_TILE_KV, block_size,
(kv_tile_idx - stage_idx) % NUM_STAGES);
cp_async::commit_group();
@@ -903,7 +907,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
if (kv_tile_idx - NUM_STAGES >= 0) {
__syncthreads();
load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
- kpe_stride_n, kpe_stride_page, kv_bound,
+ kpe_stride_n, kpe_stride_page, packed_kv_bound,
block_iter_base + (kv_tile_idx - NUM_STAGES) * CTA_TILE_KV, block_size,
(kv_tile_idx - NUM_STAGES) % NUM_STAGES);
cp_async::commit_group();
@@ -927,7 +931,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe

__syncthreads();
load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
- kpe_stride_n, kpe_stride_page, kv_bound,
+ kpe_stride_n, kpe_stride_page, packed_kv_bound,
block_iter_base + (kv_tile_idx - NUM_STAGES) * CTA_TILE_KV, block_size,
(kv_tile_idx - NUM_STAGES) % NUM_STAGES);
cp_async::commit_group();
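
As a rough sketch of why a single `packed_kv_bound` can be reused by every pipeline stage above (my own illustration with made-up values: `kv_start = 0`, `CTA_TILE_KV = 64`, `kv_len = 197`; none of these are taken from the actual build configuration): each stage loads the tile starting at `block_iter_base + kv_tile_idx * CTA_TILE_KV`, and only the final tile can straddle the element-level bound.

```python
# Hypothetical tile/bound arithmetic mirroring the kernel's packed_kv_bound usage (example values only).
def tiles_with_oob_positions(kv_len: int, block_size: int, cta_tile_kv: int, kv_indptr: int = 0):
    block_iter_base = kv_indptr * block_size              # kv_start assumed to be 0 in this sketch
    packed_kv_bound = kv_indptr * block_size + kv_len      # element-level bound from the diff above
    num_tiles = -(-kv_len // cta_tile_kv)                  # ceil division
    tiles = []
    for kv_tile_idx in range(num_tiles):
        start = block_iter_base + kv_tile_idx * cta_tile_kv
        end = start + cta_tile_kv
        oob = max(0, end - packed_kv_bound)                # positions the guard must zero-fill
        tiles.append((kv_tile_idx, start, end, oob))
    return tiles

for kv_tile_idx, start, end, oob in tiles_with_oob_positions(kv_len=197, block_size=16, cta_tile_kv=64):
    print(f"tile {kv_tile_idx}: packed range [{start}, {end}) -> {oob} zero-filled positions")
```
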
82 changes: 81 additions & 1 deletion tests/test_deepseek_mla.py
@@ -325,6 +325,86 @@ def test_batch_mla_varlen_page_attention(
# torch.testing.assert_close(lse_i, lse_ref, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7, 157])
@pytest.mark.parametrize("kv_len", [17, 33, 75, 197])
@pytest.mark.parametrize("qo_len", [3, 7, 17])
@pytest.mark.parametrize("num_heads", [16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("page_size", [16, 32])
@pytest.mark.parametrize("backend", ["fa2", "fa3"])
@pytest.mark.parametrize("dtype", [torch.half])
def test_batch_mla_oob_kv_nan(
batch_size, kv_len, qo_len, num_heads, causal, page_size, backend, dtype
):
if not is_sm90a_supported(torch.device("cuda")):
pytest.skip("FA3 is not supported on this device")
if causal and qo_len > kv_len:
pytest.skip("qo_len > kv_len not supported for causal attention")
torch.manual_seed(42)
head_dim_ckv = 512
head_dim_kpe = 64
q_nope = torch.randn(
batch_size * qo_len, num_heads, head_dim_ckv, dtype=dtype, device="cuda"
)
q_pe = torch.randn(
batch_size * qo_len, num_heads, head_dim_kpe, dtype=dtype, device="cuda"
)
pages_num = math.ceil(kv_len / page_size)
ckv = torch.randn(
batch_size * pages_num, page_size, head_dim_ckv, dtype=dtype, device="cuda"
)
kpe = torch.randn(
batch_size * pages_num, page_size, head_dim_kpe, dtype=dtype, device="cuda"
)

# Fill oob positions with nan
for i in range(batch_size):
last_page_len = kv_len - (pages_num - 1) * page_size
ckv[(i + 1) * pages_num - 1, last_page_len:, :] = float("nan")
kpe[(i + 1) * pages_num - 1, last_page_len:, :] = float("nan")

sm_scale = 1.0 / ((128 + 64) ** 0.5)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
workspace_buffer, backend=backend
)
q_indptr = (
torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
)
kv_indptr = (
torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * pages_num
)
kv_indices = torch.arange(
0, batch_size * pages_num, device="cuda", dtype=torch.int32
)
kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")

wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_lens,
num_heads,
head_dim_ckv,
head_dim_kpe,
page_size,
causal,
sm_scale,
q_nope.dtype,
ckv.dtype,
)
o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)

k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads)

q = torch.cat([q_nope, q_pe], dim=-1)
o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
lse_ref = lse_ref.flatten(0, 1)
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
if kv_len != 0:
torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
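
A brief note on why the test above poisons the out-of-bounds tail with NaN rather than a large finite value (a standalone sketch with made-up numbers, not part of the test suite): even when the attention weight of an out-of-bounds slot is masked to zero, accumulating the output as weights @ values still multiplies that zero by whatever sits in the value row, and 0 * NaN is NaN. Zero-filling the out-of-bounds loads, as the kernel change does, keeps the accumulation finite.

```python
# Standalone illustration with assumed shapes and values (not taken from the test above).
import torch

p = torch.tensor([0.3, 0.7, 0.0])                    # attention weights; the OOB slot is already masked to 0
v = torch.tensor([[1.0], [2.0], [float("nan")]])     # value rows; the OOB row holds NaN garbage
print(p @ v)                                          # tensor([nan]) -- 0 * NaN poisons the output

v_zero_filled = torch.tensor([[1.0], [2.0], [0.0]])  # what SharedMemFillMode::kFillZero gives the kernel instead
print(p @ v_zero_filled)                              # tensor([1.7000])
```
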


@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7, 157])
@pytest.mark.parametrize("kv_len", [0, 17, 33, 96, 97, 114, 514, 1024])
@pytest.mark.parametrize(
@@ -471,7 +551,7 @@ def test_batch_mla_page_attention(

if __name__ == "__main__":
test_batch_mla_varlen_page_attention(
- 155, 64, 8, 8, 128, 16, False, 1, "fa3", torch.half
+ 1, 65, 65, 65, 1, 128, True, 64, "fa2", torch.half
)
# test_batch_mla_varlen_page_attention(
# 155, 1024, 8, 128, 128, 16, False, 1, "fa3", torch.half