fix - fix bug when not relevant seq has nan data #942
Conversation
Force-pushed from 979128f to cf8b820.
include/flashinfer/attention/mla.cuh (outdated diff)
```cpp
      q < kv_bound);
  ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(
      ckv_smem_offset_w, ckv_ptr,
      (q < kv_bound - 1) || ((q == kv_bound - 1) && (r <= last_page_kv_idx)));
```
I think there might be an easier way:
- pass `kv_end` instead of `last_page_kv_idx` to this function.
- change the condition to `(packed_block_iter_base + lane_idx / 8 + warp_idx_in_wg * 4) < kv_end`

WDYT?
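For illustration, here is a minimal Python simulation of the suggested predicate (not the actual CUDA code): the per-thread row index and the `< kv_end` check come from the comment above, while the warp/lane layout (8 lanes per 128-bit row, 4 rows per warp) and the `load_ckv_rows` helper are assumptions made to keep the sketch self-contained.

```python
# Conceptual simulation of the suggested predicate: each thread's global KV row is
# packed_block_iter_base + lane_idx // 8 + warp_idx_in_wg * 4, and the 128-bit load
# is only issued when that row is < kv_end; otherwise kFillZero writes zeros to
# shared memory instead of loading whatever garbage (e.g. NaN) sits in the page.
import numpy as np

def load_ckv_rows(ckv_page, packed_block_iter_base, kv_end, num_warps=4, lanes_per_warp=32):
    """Return the rows each (warp, lane) would place into shared memory.

    ckv_page: (rows, head_dim) array that may contain NaN past kv_end.
    The warp/lane layout is an assumption for illustration; only the
    predicate itself comes from the review comment.
    """
    head_dim = ckv_page.shape[1]
    smem = {}
    for warp_idx_in_wg in range(num_warps):
        for lane_idx in range(lanes_per_warp):
            row = packed_block_iter_base + lane_idx // 8 + warp_idx_in_wg * 4
            if row < kv_end:                       # suggested predicate
                smem[(warp_idx_in_wg, lane_idx)] = ckv_page[row]
            else:                                  # kFillZero: out-of-range -> zeros
                smem[(warp_idx_in_wg, lane_idx)] = np.zeros(head_dim, dtype=ckv_page.dtype)
    return smem

# Example: rows >= kv_end hold NaN, but the predicate keeps them out of shared memory.
page = np.full((16, 8), np.nan, dtype=np.float32)
page[:5] = 1.0                                     # only the first 5 rows are valid
smem = load_ckv_rows(page, packed_block_iter_base=0, kv_end=5)
assert not any(np.isnan(v).any() for v in smem.values())
```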
I created another bugfix (#945) for the FA3 template; does that look good to you? If so, I suppose we can use a similar form here.
Okay, let me modify the commit.
Sorry, I directly modified the commit and merged it; let me know if you think there is any issue with it.
Force-pushed from cf8b820 to dfbcc0c.
The sm86/sm89 version of the MLA kernel was not tested after change #942; this PR fixes the issue. This PR also makes the following changes: 1. add the MLA unittest to CI (on an a10g node); 2. shrink the MLA unittest so that CI can finish in a reasonable time; 3. change `is_sm90a_supported(torch.device("cuda"))` to `backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):` for non-Hopper GPUs, as pointed out by @Atream.
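For item 3, the intent is presumably a skip guard along these lines (a sketch assuming a pytest-style test; the import path and the exact placement and message in `test_deepseek_mla.py` may differ):

```python
import pytest
import torch
from flashinfer.utils import is_sm90a_supported  # assumed import path

def maybe_skip_fa3(backend: str) -> None:
    """Skip FA3-backed cases on GPUs without sm90a, but keep running fa2 cases."""
    # `backend` is the test's parametrized backend name ("fa2" or "fa3").
    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
        pytest.skip("FA3 backend requires sm90a (Hopper); skipping on this GPU.")
```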
We found that in the batch MLA FA2 implementation, if NaN data exists within the computed CKV cache page, the kernel would inadvertently read these invalid values and corrupt other valid data, ultimately causing all outputs to become NaN. This PR therefore introduces masking logic so that CKV entries outside the computation scope are excluded from processing. Without the fix, other requests may receive erroneous responses when a previous request overflows in precision and writes NaN to the KV cache.
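A small NumPy sketch of the failure mode being masked out (independent of the kernel, which zero-fills the out-of-range loads rather than slicing): a NaN sitting past the valid length of a CKV page poisons any reduction that touches it.

```python
# A NaN stored past the valid length of a CKV page poisons any reduction that
# reads it, so the whole output becomes NaN unless those entries are excluded.
import numpy as np

head_dim, page_size, kv_len = 8, 16, 5
rng = np.random.default_rng(0)

ckv_page = rng.standard_normal((page_size, head_dim)).astype(np.float32)
ckv_page[kv_len:] = np.nan              # garbage left in the unused part of the page
q = rng.standard_normal(head_dim).astype(np.float32)

scores_unmasked = ckv_page @ q          # reads past kv_len -> NaN appears
scores_masked = ckv_page[:kv_len] @ q   # only the computation scope is read

print(np.isnan(scores_unmasked).any())  # True: NaN leaks into the scores
print(np.isnan(scores_masked).any())    # False: masking keeps the output finite
```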
How to reproduce
Modify `test_batch_mla_varlen_page_attention` in `test_deepseek_mla.py` to use the FA2 implementation, and set NaN data in an unrelated sequence.
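A hypothetical sketch of the reproduction idea, with illustrative tensor names and shapes rather than the actual ones from `test_deepseek_mla.py`:

```python
# Before running the FA2 batch MLA attention, write NaN into pages that belong to a
# sequence unrelated to the one being checked. With the old kernel the unrelated NaN
# corrupts the output; with the masking fix the output stays finite.
# Tensor names and shapes below are illustrative, not the test's actual ones.
import torch

num_pages, page_size, head_dim_ckv = 64, 32, 512
ckv_cache = torch.randn(num_pages, page_size, head_dim_ckv, dtype=torch.half)

# Suppose pages 40..47 are owned by an unrelated sequence (e.g. one whose earlier
# run overflowed in precision); poison them with NaN.
ckv_cache[40:48].fill_(float("nan"))

# ... then run test_batch_mla_varlen_page_attention with backend="fa2" against this
# cache; prior to this PR the outputs of the other sequences also became NaN.
```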