
perf: MLA decode kernel implemented by CuTe targeted to SM80 #844

Merged: 4 commits into flashinfer-ai:main on Feb 14, 2025

Conversation


@tsu-bin (Contributor) commented on Feb 14, 2025

Hi @yzh119 , this is a follow-up of #766. An interesting idea came to my mind today, and I couldn't help changing a few lines to verify it.
We can use an asymmetric warp config to work around the register-file size limit: simply use 8 warps for the output mma stage and keep the other parts unchanged. The limit is on the number of registers per CUDA thread, not per SM; each SM has 64K 32-bit registers, which is enough for the f32 output of 64 heads.
So we now have 4 warps for the attention mma stage, 2 warps for the softmax stage, 8 warps for the output mma stage, and 4 warps for the data-load stage; the updated diagram is below:
[diagram: warp-to-stage assignment after the change]
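To make the register argument concrete, here is a back-of-the-envelope count and a minimal warp-role sketch. It assumes head_dim_ckv = 512 and the 64 q-heads of f32 output mentioned above; the stage-to-warp mapping is only illustrative, not the actual kernel code:

// Illustrative sketch only, not this PR's kernel.
// Output accumulator per CTA: 64 heads * 512 dims * 4 B = 128 KB = 32768 32-bit registers,
// i.e. half of the 64K-register file of an SM80 SM.
//   spread over 4 warps (128 threads): 32768 / 128 = 256 regs/thread -> over the 255/thread cap
//   spread over 8 warps (256 threads): 32768 / 256 = 128 regs/thread -> fits
__global__ void mla_decode_warp_roles_sketch(/* ... */) {
  const int warp_id = threadIdx.x / 32;          // 8 warps per CTA assumed
  const bool attn_mma_warp   = warp_id < 4;      // 4 warps: q @ k^T mma stage
  const bool softmax_warp    = warp_id < 2;      // 2 warps: online softmax stage
  const bool output_mma_warp = true;             // all 8 warps: p @ v with the big f32 accumulator
  const bool load_warp       = warp_id >= 4;     // 4 warps: global -> shared data-load stage
  // ... per-stage work elided ...
}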

After the change, the output mma stage needs more computation, so the benchmark drops a little as expected, but it still looks good:
[benchmark results after the change]

It seems the performance of this CuTe implementation is slightly better than that of the current FA2 implementation, according to #814:
[comparison against the FA2 implementation from #814]

So I think this CuTe implementation still has value, considering its interesting scheduling design and better performance; maybe we can treat it as an ad hoc implementation for the (decode-only / 128 q-heads / SM80) case, and the JIT logic can accommodate this kernel.
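A hypothetical dispatch guard for that case might look like the following; this is only a sketch of the idea, not FlashInfer's actual JIT code, and the function name and arguments are made up:

// Hypothetical routing predicate: fall back to the generic path unless we are in
// the narrow case this kernel targets (decode-only, 128 q-heads, SM80).
bool use_cute_sm80_mla_decode(int qo_len, int num_qo_heads, int sm_major, int sm_minor) {
  const bool decode_only = (qo_len == 1);
  const bool sm80 = (sm_major == 8 && sm_minor == 0);
  return decode_only && num_qo_heads == 128 && sm80;
}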


@yzh119 (Collaborator) commented on Feb 14, 2025

Hi @tsu-bin , thank you for the contribution; it looks good to me in general.
I tried running this PR with the same tests as MLAPageAttentionWrapper (https://gist.github.com/yzh119/ca06d89574bf5b084339560b926226f6), and the unit tests passed, so I suppose this PR is ready to be merged.

I also tried this PR on the same benchmark as #814 (https://gist.github.com/yzh119/4945fe2192792962863649bfb5f218d4), and here is what I get on an A100 SXM4 40GB:

Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 485.05 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 485.72 GB/s

I guess it's because your schedule's parameters are tuned for the 4090.
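For context, here is a rough way to read those numbers, assuming the reported bandwidth is simply the fp16 KV-cache bytes streamed per decode step divided by kernel latency, and assuming head_dim_ckv = 512 and head_dim_kpe = 64:

// Back-of-envelope check for the batch_size=768, seq_len=1024 config above.
#include <cstdio>

int main() {
  const double batch = 768, seq_len = 1024;
  const double head_dim_ckv = 512, head_dim_kpe = 64, bytes_per_elem = 2;  // fp16
  const double kv_bytes = batch * seq_len * (head_dim_ckv + head_dim_kpe) * bytes_per_elem;
  const double reported_gbps = 485.05;  // figure reported above
  std::printf("KV bytes per decode step: %.0f MB\n", kv_bytes / 1e6);                          // ~906 MB
  std::printf("implied kernel latency:   %.2f ms\n", kv_bytes / (reported_gbps * 1e9) * 1e3);  // ~1.87 ms
  // A100 40GB HBM2 peak is ~1555 GB/s, so ~485 GB/s is roughly a third of peak,
  // consistent with the schedule being tuned for a different GPU.
  return 0;
}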

Regarding the performance on the 4090, the results for #814 and this PR are not directly comparable because the settings are different: this PR's benchmark zero-initializes its buffers,

thrust::device_vector<T> q_nope(batch_size * num_qo_heads * head_dim_ckv);
thrust::device_vector<T> q_pe(batch_size * num_qo_heads * head_dim_kpe);
thrust::device_vector<T> ckv_data(num_pages * page_size * head_dim_ckv);
thrust::device_vector<T> kpe_data(num_pages * page_size * head_dim_kpe);

while #814 uses Gaussian initialization,

q_nope = torch.randn(
    batch_size * qo_len, num_heads, head_dim_ckv, dtype=torch.half, device="cuda"
)
q_pe = torch.randn(
    batch_size * qo_len, num_heads, head_dim_kpe, dtype=torch.half, device="cuda"
)
ckv = torch.randn(
    batch_size * kv_len // page_size,
    page_size,
    head_dim_ckv,
    dtype=torch.half,
    device="cuda",
)
kpe = torch.randn(
    batch_size * kv_len // page_size,
    page_size,
    head_dim_kpe,
    dtype=torch.half,
    device="cuda",
)

which tends to produce lower TFLOP/s. A sketch of switching the thrust benchmark to Gaussian inputs is shown below.
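For what it's worth, a minimal sketch of Gaussian initialization for the thrust benchmark could look like this (assuming T is __half; randn_device and the seed are made-up names for illustration, not code from this repo):

#include <random>
#include <vector>
#include <cuda_fp16.h>
#include <thrust/device_vector.h>

// Fill a device vector with values drawn from N(0, 1), generated on the host.
template <typename T>
thrust::device_vector<T> randn_device(size_t n, unsigned seed = 42) {
  std::mt19937 gen(seed);
  std::normal_distribution<float> dist(0.f, 1.f);
  std::vector<T> host(n);
  for (auto& x : host) x = static_cast<T>(dist(gen));  // __half converts from float
  return thrust::device_vector<T>(host.begin(), host.end());
}

// Usage mirroring the zero-initialized buffers quoted above:
//   auto q_nope   = randn_device<__half>(batch_size * num_qo_heads * head_dim_ckv);
//   auto ckv_data = randn_device<__half>(num_pages * page_size * head_dim_ckv);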

The schedule in #814 could also be further improved on the 4090: the NUM_STAGES=2 and QK_SHARD=true used in this PR could indeed improve performance. The reason I didn't do that in #814 is that I hardcoded the minimal mma $n$ size as 16, while you use $8$, which is more flexible and enables a smaller TILE_SIZE of 8 per warpgroup; I agree that CuTe is flexible in terms of supporting this.
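Concretely, the flexibility here is that CuTe lets you build the tiled mma directly on the m16n8k16 atom, so an $n$ granularity of 8 per atom falls out naturally. A minimal sketch (the warp layout is just an example, not this PR's actual configuration):

#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

using namespace cute;

// The SM80 fp16 mma.sync atom is m16n8k16, so the minimal n tile per atom is 8.
using MmaAtom  = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
// Example: 2 warps along M, 1 along N; the small n granularity is what allows a
// TILE_SIZE of 8 per warpgroup as discussed above.
using TiledMma = decltype(make_tiled_mma(MmaAtom{}, Layout<Shape<_2, _1, _1>>{}));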

In general I'm good with this PR, let's move on and unify different schedules in later PRs.

yzh119 merged commit 88e3dee into flashinfer-ai:main on Feb 14, 2025