
[Refactor] Uniform PoDAttention API with Horizontal Fusion SMs Schedule #967

Open · wants to merge 12 commits into main

Conversation

happierpig (Collaborator)

Description

This PR is a follow-up to #858; it integrates the PoDAttention (arXiv link) API in a user-transparent manner. Users can now invoke PoDAttention through the same API as BatchPrefillWithPagedKVCache, without explicitly specifying whether requests are prefill or decode (example code).
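For context, here is a minimal sketch of what issuing a mixed prefill/decode batch through the unified wrapper could look like. The tensor shapes and the plan/run argument order follow the public BatchPrefillWithPagedKVCacheWrapper documentation and are assumptions for illustration, not code from this PR; argument names may differ across FlashInfer versions.

```python
import torch
import flashinfer

# Hypothetical mixed batch: request 0 is a 512-token prefill, request 1 is a
# single-token decode. With this PR the same wrapper serves both; the scheduler
# classifies and fuses them internally.
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")

qo_indptr = torch.tensor([0, 512, 513], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 32, 96], dtype=torch.int32, device="cuda")   # pages per request
kv_indices = torch.arange(96, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.tensor([16, 5], dtype=torch.int32, device="cuda")

wrapper.plan(
    qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size, causal=True,
)

q = torch.randn(513, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_cache = torch.randn(96, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")
o = wrapper.run(q, kv_cache)  # output shape: (513, num_qo_heads, head_dim)
```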

Key Changes

  1. Support for Non-Continuous Q/O and KV Tensor Layout
    Previously, tensor offsets were computed from indptr, which assumes a continuous layout. PoDAttention must handle mixed prefill and decode subsets of requests within a batch, which necessitates a non-continuous layout.

    • Added q_lenptr and kv_lenptr to support this (code link); see the sketch after this list.
  2. Horizontal Fusion-Style Implementation
    For improved efficiency, the prefill and decode subsets of requests are scheduled with awareness of each other, enabling better selection of kernel hyperparameters and persistent kernel execution.

    • The current resource-partitioning strategy depends solely on the total KV-cache load size (scheduler code); see the sketch after this list.
    • Note: this strategy can be customized for specific workloads.
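To illustrate the two items above, here is a hypothetical host-side sketch. Only the ideas (explicit q_lenptr/kv_lenptr arrays and partitioning by total KV-cache load) come from the PR description; the helper names and the proportional split are illustrative, not the actual scheduler code.

```python
import torch

def build_lenptr(qo_indptr: torch.Tensor, kv_indptr: torch.Tensor):
    # Hypothetical helper: derive per-request lengths from indptr arrays.
    # With a continuous layout, a request's offset is simply indptr[i]; once
    # prefill and decode subsets are packed separately, explicit per-request
    # lengths (q_lenptr / kv_lenptr) are carried instead, so each subset can
    # keep its own base offsets inside the fused kernel.
    q_lenptr = qo_indptr[1:] - qo_indptr[:-1]
    kv_lenptr = kv_indptr[1:] - kv_indptr[:-1]
    return q_lenptr, kv_lenptr

def split_sms_by_kv_load(kv_lenptr_prefill: torch.Tensor,
                         kv_lenptr_decode: torch.Tensor,
                         num_sms: int):
    # Hypothetical partitioning rule matching the description above: each
    # subset receives a share of SMs proportional to its total KV-cache load.
    load_p = int(kv_lenptr_prefill.sum())
    load_d = int(kv_lenptr_decode.sum())
    total = max(load_p + load_d, 1)
    sms_p = min(num_sms - 1, max(1, round(num_sms * load_p / total)))
    return sms_p, num_sms - sms_p
```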

Limitations and Future Work

  • CUDA Graph is not yet supported, and only the FA2 backend is supported at this stage.
  • The workload classifier (qo_len > threshold) is preliminary and needs improvement (classifier implementation); see the sketch after this list.
  • Performance tuning is ongoing, and correctness has only been validated on a limited set of unit tests (unit tests).

cc @AKKamath @yzh119
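For reference, the preliminary classifier amounts to a simple length threshold. A minimal sketch, assuming the qo_len > threshold rule mentioned above; the default threshold here is a placeholder, not the value used in the PR:

```python
import torch

def classify_requests(qo_indptr: torch.Tensor, threshold: int = 1):
    # Split a batch into prefill and decode subsets using the preliminary
    # qo_len > threshold rule; requests with at most `threshold` query tokens
    # are treated as decode.
    qo_len = qo_indptr[1:] - qo_indptr[:-1]
    prefill_ids = torch.nonzero(qo_len > threshold, as_tuple=True)[0]
    decode_ids = torch.nonzero(qo_len <= threshold, as_tuple=True)[0]
    return prefill_ids, decode_ids
```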

happierpig requested a review from yzh119 on March 21, 2025 at 21:28
@yzh119 (Collaborator) commented Mar 21, 2025

Some of the unit tests failed, for example test_block_sparse_attention[False-256-16-16-128-64-16-4]:

RuntimeError: Error in function 'PrefillSplitQOKVIndptr' at /workspace/flashinfer/data/include/flashinfer/attention/scheduler.cuh:515: kv_len_ptr_h[0]: 0 should be positive

@Edenzzzz (Contributor) commented Apr 9, 2025

Hi, can I ask when this is planned to be merged? I made a PR to support POD Attn in SGLang using the old API and plan to get that working with CUDA graph first.
sgl-project/sglang#5169

@AKKamath (Contributor) commented Apr 9, 2025

I really like the uniform batch API that this PR presents.

I ran this on an A100 and compared it with the existing FlashInfer POD-Attention implementation. On average this performed around 10-15% worse, but still better than serial execution. Performance was worse for larger prefill context lengths, while for smaller context lengths performance was more comparable.

@Edenzzzz (Contributor) commented Apr 9, 2025

Yeah, this is more convenient. One issue I had in my PR is that I have to fill a 2D attention mask for prefill every time, instead of using the page table and indices.

@yzh119 (Collaborator) commented Apr 10, 2025

Hi @Edenzzzz @AKKamath, I'm working on another branch that follows this idea; it will be merged in the coming days.

@Edenzzzz (Contributor)

Will the old API be preserved? Thanks.

@Edenzzzz (Contributor)

@AKKamath By the way, I wonder what the reason was for using a mask instead of a page table for the prefill QKV?

@AKKamath (Contributor)

@yzh119 can correct me here, but I believe the mask-based prefill kernel (single_prefill) had better performance than the page-table prefill, because the page-table prefill had higher register usage, causing register spills.
