[Bug] MLA kernel fails the tests in tests/test_deepseek_mla.py #949

Closed
Atream opened this issue Mar 17, 2025 · 7 comments
Atream (Contributor) commented Mar 17, 2025

The MLA kernel fails the tests in tests/test_deepseek_mla.py. On the current main branch (commit 27906fd) the unit tests do not pass, and the output in the integrated system is also abnormal. After reverting to the previous commit 061db55, everything works fine.

Mismatched elements: 8318200 / 8388608 (99.2%)
Greatest absolute difference: 1.974609375 at index (0, 44, 501) (up to 0.001 allowed)
Greatest relative difference: inf at index (0, 18, 121) (up to 0.001 allowed)

Environment

RTX 4090, CUDA 12.4, torch 2.5.1
Failures occur in test_batch_mla_varlen_page_attention and test_batch_mla_page_attention with BFloat16.
To test on the 4090, I removed the `if not is_sm90a_supported(torch.device("cuda"))` check.
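
For context, a minimal way to inspect the local GPU's compute capability (a sketch using only standard torch calls; the RTX 4090 is sm89, while the guard above checks for sm90a):

```python
import torch

def compute_capability(device: torch.device = torch.device("cuda")) -> tuple:
    # Returns (major, minor): (8, 9) for an RTX 4090 (sm89), (9, 0) for Hopper (sm90/sm90a).
    return torch.cuda.get_device_capability(device)

if __name__ == "__main__":
    major, minor = compute_capability()
    print(f"sm{major}{minor}")  # prints "sm89" on an RTX 4090
```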

yzh119 (Collaborator) commented Mar 17, 2025

Thanks for reporting the issue. I'll fix it soon and add MLA unittests to CI.

yzh119 (Collaborator) commented Mar 17, 2025

Hi @Atream, I can't reproduce the issue. Can you show me the exact test case (the batch_size/kv_len/qo_len/etc. in test_batch_mla_page_attention) that generates the wrong outputs:

> Mismatched elements: 8318200 / 8388608 (99.2%)
> Greatest absolute difference: 1.974609375 at index (0, 44, 501) (up to 0.001 allowed)
> Greatest relative difference: inf at index (0, 18, 121) (up to 0.001 allowed)

> To test on the 4090, I removed the `if not is_sm90a_supported(torch.device("cuda"))` check.

The 4090 does not support fa3, which relies on wgmma instructions that are only available on sm90a (the 4090 is sm89); you can try the fa2 backend in this case.
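
A minimal sketch of the backend selection this suggests (assuming `is_sm90a_supported` is importable from `flashinfer.utils`, as in the test file; the fallback logic itself is only an illustration, not library code):

```python
import torch
from flashinfer.utils import is_sm90a_supported  # assumed import path, as used in the tests

def pick_mla_backend(device: torch.device = torch.device("cuda")) -> str:
    # fa3 relies on Hopper-only wgmma instructions (sm90a);
    # on sm89 parts such as the RTX 4090, fall back to fa2.
    return "fa3" if is_sm90a_supported(device) else "fa2"

print(pick_mla_backend())  # expected to print "fa2" on an RTX 4090
```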

Atream (Contributor, Author) commented Mar 17, 2025

> Hi @Atream, I can't reproduce the issue. Can you show me the exact test case (the batch_size/kv_len/qo_len/etc. in test_batch_mla_page_attention) that generates the wrong outputs:
>
> > Mismatched elements: 8318200 / 8388608 (99.2%)
> > Greatest absolute difference: 1.974609375 at index (0, 44, 501) (up to 0.001 allowed)
> > Greatest relative difference: inf at index (0, 18, 121) (up to 0.001 allowed)
>
> > To test on the 4090, I removed the `if not is_sm90a_supported(torch.device("cuda"))` check.
>
> The 4090 does not support fa3, which relies on wgmma instructions that are only available on sm90a (the 4090 is sm89); you can try the fa2 backend in this case.

I ran this:

test_batch_mla_page_attention(1, 1024, 128, 128, False, 1, "fa2", True, torch.bfloat16)

yzh119 (Collaborator) commented Mar 17, 2025

This should have been fixed in #951; you can check the unittest status at https://ci.tlcpack.ai/blue/organizations/jenkins/flashinfer-ci/detail/PR-951/2/pipeline (GPU-G5-Test-4).

Atream (Contributor, Author) commented Mar 17, 2025

I tested in my environment.

test_batch_mla_page_attention(1, 1024, 128, 128, True, 1, "fa2", True, torch.bfloat16)

Mismatched elements: 33698 / 8388608 (0.4%)
Greatest absolute difference: 0.0078125 at index (0, 123, 162) (up to 0.001 allowed)
Greatest relative difference: 0.048828125 at index (30, 32, 363) (up to 0.001 allowed)
test_batch_mla_varlen_page_attention(1, 65, 65, 65, 1, 128, True, 64, "fa2", torch.bfloat16)
Mismatched elements: 7082 / 65536 (10.8%)
Greatest absolute difference: 0.015625 at index (0, 108, 276) (up to 0.001 allowed)
Greatest relative difference: 18.75 at index (0, 119, 158) (up to 0.001 allowed)

yzh119 (Collaborator) commented Mar 17, 2025

Hi @Atream, that's expected because the original atol and rtol were designed for fp16. bf16 inherently has larger errors (as studied in https://arxiv.org/abs/2405.02803); usually we can tolerate a 2e-2 difference in bf16 unittests.

For bf16 unittests, we need to increase the atol and rtol accordingly.
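
A minimal sketch of dtype-dependent tolerances (the 2e-2 figure comes from the comment above; the helper name and the use of torch.testing.assert_close are illustrative, not necessarily how the test file implements the check):

```python
import torch

def assert_close_by_dtype(out: torch.Tensor, ref: torch.Tensor) -> None:
    # bf16 has fewer mantissa bits than fp16, so comparisons need looser tolerances.
    if out.dtype == torch.bfloat16:
        rtol, atol = 2e-2, 2e-2   # tolerance suggested for bf16 in the comment above
    else:
        rtol, atol = 1e-3, 1e-3   # the fp16-oriented tolerances seen in the report above
    torch.testing.assert_close(out.float(), ref.float(), rtol=rtol, atol=atol)
```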

The end-to-end evaluation after #951 should be normal.

Atream (Contributor, Author) commented Mar 17, 2025

It works fine. Thank you for your quick fix.

Atream closed this as completed Mar 17, 2025
yzh119 added a commit that referenced this issue Mar 17, 2025
The sm86/sm89 version of the MLA kernel was not tested after change #942; this PR fixes the issue.

This PR also makes the following changes:
1. adds the MLA unittests to CI (on an a10g node);
2. shrinks the MLA unittests so that CI can finish in a reasonable time;
3. changes `is_sm90a_supported(torch.device("cuda"))` to `backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):` for non-Hopper GPUs, as pointed out by @Atream (see the sketch below).
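
A minimal sketch of what item 3 amounts to inside a parametrized test (pytest-style; the helper name is illustrative, and `is_sm90a_supported` is assumed importable from `flashinfer.utils`):

```python
import pytest
import torch
from flashinfer.utils import is_sm90a_supported  # assumed import path

def maybe_skip_backend(backend: str) -> None:
    # The old guard skipped every test whenever sm90a was unavailable, which also
    # disabled fa2 runs on sm86/sm89. Only the fa3 backend actually needs sm90a
    # (wgmma), so the new guard lets fa2 keep running on non-Hopper GPUs.
    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
        pytest.skip("fa3 MLA kernels require sm90a (Hopper wgmma)")
```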
yyihuang pushed a commit to yyihuang/flashinfer that referenced this issue Mar 17, 2025