fix - fix bug when unrelated seq has NaN data #942

Merged
merged 2 commits into flashinfer-ai:main from fix/fix_mla_fa2_nan on Mar 15, 2025

Conversation

baowendin
Contributor

We found that in the batch MLA FA2 implementation, if NaN data exists within the computed CKV cache page, the kernel inadvertently reads these invalid values and corrupts other valid data, ultimately causing all outputs to become NaN. This PR therefore introduces masking logic to exclude CKV entries outside the computation scope from being processed. Without the fix, a request whose precision overflows and writes NaN into the KV cache can cause other, unrelated requests to return erroneous responses.
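
For illustration only, a toy PyTorch sketch (not the kernel code) of why a single stray NaN is enough to wipe out a whole request: once a NaN score enters the softmax, every element of that row becomes NaN.

    # Toy sketch with made-up scores; the last slot stands for a KV position
    # past kv_len whose page happens to contain NaN.
    import torch

    scores = torch.tensor([0.5, 1.2, float("nan")])
    print(torch.softmax(scores, dim=0))      # tensor([nan, nan, nan])

    # Excluding the out-of-scope slot keeps the result finite; the PR does this
    # at the shared-memory load by predicating the copy and zero-filling
    # out-of-bound entries instead of reading stale page contents.
    print(torch.softmax(scores[:2], dim=0))  # finite probabilities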

How to reproduce

  1. Modify test_batch_mla_varlen_page_attention in test_deepseek_mla.py to use the fa2 implementation, and set a NaN value in an unrelated sequence:
    # disable fa3 check
    # if not is_sm90a_supported(torch.device("cuda")):
    #     pytest.skip("FA3 is not supported on this device")
    if causal and qo_len > min(kv_len_0, kv_len_1, kv_len_2):
        pytest.skip("qo_len > kv_len not supported for causal attention")
    num_different_kv_len = 3
    .....
    kpe = torch.randn(
        batch_size * pages_nums_sum,
        page_size,
        head_dim_kpe,
        dtype=dtype,
        device="cuda",
    )
    # set an entry of an unrelated sequence to NaN
    ckv[1][10] = float("nan")
  2. Use the test input below:

test_deepseek_mla.py

if __name__ == "__main__":    
    test_batch_mla_varlen_page_attention(
        1, 65, 65, 65, 1, 128, True, 64, "fa2", torch.half
    )
  3. NaN values can be observed in the output:
2025-03-14 10:06:40,470 - INFO - flashinfer.jit: Finished loading JIT ops: batch_mla_attention_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_ckv_512_head_dim_kpe_64_profiler_False
tensor([[[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         ...,
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan]],

        [[ 0.2800, -0.1050,  0.0584,  ..., -0.3953,  0.2433,  0.0251],
         [ 0.2006,  0.2817, -0.1516,  ...,  0.2883,  0.2686, -0.1577],
         [ 0.1216, -0.8970, -0.1066,  ...,  0.5020, -0.7349,  0.3784],
         ...,
         [ 0.0591,  0.5034, -0.3750,  ..., -0.2786, -0.2483,  0.0037],
         [ 0.2854,  0.1653, -0.1549,  ..., -0.5571, -0.3015,  0.0625],
         [ 0.0942, -0.3574,  0.1144,  ...,  0.1730, -0.2646,  0.2479]],

        [[ 0.0208, -0.3357,  0.2020,  ..., -0.1625, -0.0917,  0.2349],
         [ 0.0799,  0.6021, -0.0385,  ...,  0.5962,  0.3550,  0.2208],
         [ 0.3367, -0.5327, -0.2496,  ...,  0.0653, -0.2144,  0.1873],
         ...,
         [ 0.0577, -1.1670,  1.5000,  ...,  1.4023,  0.7612, -0.6890],
         [-0.2124, -0.2549,  0.3342,  ..., -0.5610, -0.1438, -0.1825],
         [ 0.0499, -0.3467, -0.1461,  ..., -0.0058,  0.0474,  0.3188]]],
       device='cuda:0', dtype=torch.float16)

    ... q < kv_bound);
    ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(
        ckv_smem_offset_w, ckv_ptr,
        (q < kv_bound - 1) || ((q == kv_bound - 1) && (r <= last_page_kv_idx)));
Collaborator

I think there might be an easier way:

  1. Pass kv_end instead of last_page_kv_idx to this function.
  2. Change the condition to `(packed_block_iter_base + lane_idx / 8 + warp_idx_in_wg * 4) < kv_end`.

WDYT?
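
A minimal standalone sketch of why these should match, assuming q is the page index, r the in-page offset, kv_bound the page count of the sequence, and kv_end = (kv_bound - 1) * page_size + last_page_kv_idx + 1 (the names here are illustrative, not the kernel's exact variables):

    # Brute-force check that the packed-index comparison suggested above is
    # equivalent to the two-case predicate currently used in this PR.
    page_size = 64

    def predicate_pr(q, r, kv_bound, last_page_kv_idx):
        return (q < kv_bound - 1) or (q == kv_bound - 1 and r <= last_page_kv_idx)

    def predicate_kv_end(q, r, kv_end):
        packed_kv_idx = q * page_size + r  # stand-in for packed_block_iter_base + ...
        return packed_kv_idx < kv_end

    for kv_bound in range(1, 4):
        for last_page_kv_idx in range(page_size):
            kv_end = (kv_bound - 1) * page_size + last_page_kv_idx + 1
            for q in range(kv_bound + 1):
                for r in range(page_size):
                    assert predicate_pr(q, r, kv_bound, last_page_kv_idx) == \
                        predicate_kv_end(q, r, kv_end)
    print("both predicates agree on all tested indices")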

Collaborator
@yzh119 yzh119 Mar 14, 2025

I created another bugfix (#945) for the FA3 template; does that look good to you? If so, I suppose we can use a similar form here.

Contributor Author

okay, let me modify the commit

Collaborator

Sorry, I directly modified the commit and merged it; let me know if you think there is any issue with it.

@yzh119 yzh119 force-pushed the fix/fix_mla_fa2_nan branch from cf8b820 to dfbcc0c on March 15, 2025 02:13
@yzh119 yzh119 merged commit 27906fd into flashinfer-ai:main Mar 15, 2025
2 checks passed
@yzh119 yzh119 mentioned this pull request Mar 17, 2025
yzh119 added a commit that referenced this pull request Mar 17, 2025
The sm86/sm89 version of the MLA kernel was not tested after change #942;
this PR fixes the issue.

This PR also makes the following changes:
1. adding the MLA unittest to CI (on an a10g node).
2. shrinking the MLA unittest so that CI can finish in a reasonable time.
3. changing `is_sm90a_supported(torch.device("cuda"))` to `backend == "fa3" and not is_sm90a_supported(torch.device("cuda"))` for non-Hopper GPUs, as pointed out by @Atream (a sketch of this guard change follows below).
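
A minimal sketch of the guard change in point 3, reusing names from the test snippet quoted earlier in this PR; the import path of is_sm90a_supported is an assumption and the actual diff may differ:

    # Sketch only; not the exact test code.
    import pytest
    import torch
    from flashinfer.utils import is_sm90a_supported  # assumed import location

    def maybe_skip_fa3(backend: str) -> None:
        # Before the follow-up fix, the skip applied on every non-sm90a GPU,
        # regardless of backend:
        #     if not is_sm90a_supported(torch.device("cuda")):
        #         pytest.skip("FA3 is not supported on this device")
        # Now only the fa3 backend requires sm90a, so fa2 keeps running on sm86/sm89.
        if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
            pytest.skip("FA3 is not supported on this device")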
yyihuang pushed a commit to yyihuang/flashinfer that referenced this pull request Mar 17, 2025