[Refactor] Uniform PoDAttention API with Horizontal Fusion SMs Schedule #967
Conversation
Some of the unit tests failed, for example test_block_sparse_attention[False-256-16-16-128-64-16-4].
Hi, can I ask when this is planned to be merged? I made a PR to support POD Attn in SGLang using the old API and plan to get that working with CUDA graph first.
I really like the uniform batch API that this PR presents. I ran this on an A100 and compared it with the existing FlashInfer POD-Attention implementation. On average this performed around 10-15% worse, but still better than serial execution. Performance was worse for larger prefill context lengths, while for smaller context lengths it was more comparable.
Yeah, this is more convenient. One issue I had during my PR was that I had to fill a 2D attention mask for prefill every time, instead of using the page table and indices.
Will the old API be preserved? Thanks.
@AKKamath Btw, I wonder what the reason was for using a mask instead of a page table for the prefill qkv?
@yzh119 can correct me here, but I believe the mask prefill kernel (single_prefill) had better performance than the page-table prefill, because the page-table prefill had higher register usage, causing register spills.
Description
This PR is a follow-up to #858, which integrates the PoDAttention (arXiv link) API in a user-transparent manner. Users can now invoke PoDAttention via the same API as BatchPrefillWithPagedKVCache, without explicitly specifying whether requests are prefill or decode (example code).
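For context, a minimal sketch of what mixed-batch usage might look like from Python is below. The wrapper name comes from the existing FlashInfer API, the batch shapes are made up, and the plan/run arguments are abbreviated; the PR's example code is the authoritative usage.

```python
import torch
import flashinfer

# Hypothetical mixed batch: two prefill requests (qo_len 512 and 256) and two
# decode requests (qo_len 1). With the uniform API the caller does not label
# requests as prefill or decode; the scheduler classifies them internally.
qo_lens = torch.tensor([512, 256, 1, 1], dtype=torch.int32)
qo_indptr = torch.zeros(len(qo_lens) + 1, dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(qo_lens, dim=0).to("cuda")

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")

# plan()/run() arguments are abbreviated here; the paged-KV indptr/indices/
# last_page_len tensors are omitted for brevity (see the PR's example code).
# wrapper.plan(qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
#              num_qo_heads, num_kv_heads, head_dim, page_size, causal=True)
# out = wrapper.run(q, paged_kv_cache)
```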
Key Changes
Support for Non-Contiguous Q/O and KV Tensor Layout
Previously, tensor offsets were computed using indptr, assuming contiguous layouts. PoDAttention requires supporting mixed prefill/decode subsets within a batch, necessitating a non-contiguous layout. New q_lenptr and kv_lenptr arrays are introduced to accommodate this functionality (code link).
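As a rough illustration of the difference (the exact semantics are in the linked code; everything here other than indptr, q_lenptr, and kv_lenptr is a made-up name):

```python
import torch

qo_lens = torch.tensor([512, 256, 1, 1])        # mixed prefill/decode batch
kv_lens = torch.tensor([512, 256, 1024, 2048])

# Contiguous layout: offsets are implied by a cumulative indptr, so every
# request's rows must be packed back-to-back in one tensor.
qo_indptr = torch.zeros(len(qo_lens) + 1, dtype=torch.int64)
qo_indptr[1:] = torch.cumsum(qo_lens, dim=0)
req2_offset_contiguous = int(qo_indptr[2])       # fixed by the packing order

# Non-contiguous layout (assumed semantics): per-request lengths are kept in
# q_lenptr / kv_lenptr and start offsets are supplied separately, so the
# prefill and decode subsets may live in different regions of the Q/O and KV
# buffers instead of one packed tensor.
q_lenptr = qo_lens.clone()
kv_lenptr = kv_lens.clone()
q_startptr = torch.tensor([0, 512, 4096, 4097])  # hypothetical explicit offsets
req2_offset_noncontiguous = int(q_startptr[2])
```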
Horizontal Fusion-Style Implementation
For improved efficiency, the prefill and decode subsets of a batch are aware of each other, enabling better selection of kernel hyperparameters and execution as a single persistent kernel.
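The scheduling idea can be pictured with a small host-side sketch (purely illustrative; the actual scheduler lives in the CUDA kernels): tiles from both subsets are planned together into one work list that a single persistent grid drains, instead of launching separate prefill and decode kernels.

```python
# Illustrative host-side planning for a horizontally fused launch: tiles from
# the prefill and decode subsets go into one shared work list.
def plan_fused_work(qo_lens, tile_size_prefill=128):
    work = []
    for req_id, qo_len in enumerate(qo_lens):
        if qo_len > 1:   # "prefill" subset: split long queries into tiles
            for tile_start in range(0, qo_len, tile_size_prefill):
                work.append(("prefill", req_id, tile_start))
        else:            # "decode" subset: one single-query tile per request
            work.append(("decode", req_id, 0))
    return work

work_list = plan_fused_work([512, 256, 1, 1])
# A persistent kernel with one thread block per SM would atomically pop items
# from this list, running the prefill or decode code path per item, so both
# kinds of work share the GPU's SMs within one launch.
```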
Limitations and Future Work
The current heuristic for classifying a request as prefill or decode (qo_len > threshold) is preliminary and requires improvement (classifier implementation).
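A guess at the shape of that heuristic (the function name and default threshold here are made up):

```python
# Hypothetical form of the current classifier: treat a request as prefill when
# its query length exceeds a fixed threshold, otherwise as decode.
def classify_request(qo_len: int, threshold: int = 1) -> str:
    return "prefill" if qo_len > threshold else "decode"

assert classify_request(512) == "prefill"
assert classify_request(1) == "decode"
```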
cc @AKKamath @yzh119