Description
This issue is a follow-up to #887. Per #892 (comment), we found that flashinfer's MLA implementation is slower than FlashMLA in many cases, so we created this issue to track the remaining items for improving flashinfer MLA performance (mainly on Hopper):
Performance Tracking Table
Contributed by @abcdabcd987:
https://docs.google.com/spreadsheets/d/1t0Txa7Ph9u7Su9LyWpS24vqr9A5FB-FyL0EZNpYOqwg/edit?gid=0#gid=0
Checklist
- Slower at low batch size (mainly because of split-k)
  - The second stage of split-k is slow because of the vector size (perf: fix the performance of second stage of split-k #894)
  - Load imbalance in the second stage of split-k (perf: fix MLA split-k performance bug #898)
  - Use shared memory to accelerate the second stage of split-k.
  - Try using bf16/fp16 as the partial-output data type (we currently use fp32, which is better for accuracy but results in more memory transactions) (perf: use f16 as split-k partial output data type #900)
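For context on what the second stage does, here is a minimal, illustrative Python sketch of the split-k merge math (the function names are made up for this sketch, not flashinfer's API). Each KV split produces a locally-normalized partial output plus its log-sum-exp, and the second stage renormalizes and sums the partials:

```python
import math

def local_attn(scores, values):
    """Softmax-weighted sum over one contiguous KV split (reference math).

    Returns the locally-normalized output and the split's log-sum-exp.
    """
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    z = sum(e)
    dim = len(values[0])
    out = [sum(e[j] * values[j][d] for j in range(len(scores))) / z
           for d in range(dim)]
    return out, m + math.log(z)

def merge_splits(partials):
    """Second stage of split-k: combine (output, lse) pairs from each split.

    Each partial output is reweighted by exp(lse_k - lse_total); the partial
    outputs are exactly the values whose storage precision (fp32 vs. bf16/fp16)
    trades accuracy against memory traffic.
    """
    lse = math.log(sum(math.exp(l) for _, l in partials))
    dim = len(partials[0][0])
    merged = [0.0] * dim
    for out_k, lse_k in partials:
        w = math.exp(lse_k - lse)  # this split's share of the total softmax mass
        for d in range(dim):
            merged[d] += w * out_k[d]
    return merged, lse
```

Splitting the KV axis, running `local_attn` on each piece, and merging reproduces the full-attention result exactly (in exact arithmetic), which is why the second stage is a pure reduction and a natural candidate for shared memory.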
- Slower for `qo_len * head_dim > 64` (we split `qo_len * head_dim` by a tile size of 64, and different query tiles are dispatched to different CTAs; we need to improve the KV-cache access pattern for the 2 CTAs within a cluster)
  - Use cluster sync to increase the L2 hit rate.
- Use TMA and multi-casting for `page_size >= 16`
- Try different pipeline designs
- Try FlashMLA-style warp specialization: FlashMLA and perf: FlashAttention-3 style MLA PageAttention #887 use different pipeline and warp-specialization designs; more specifically (perf: Use 2WG pipeline design for MLA implementation on Hopper #952):
  - Both FlashMLA and FlashInfer split PV on the head dimension, but FlashMLA does not split QK, while FlashInfer splits QK on the KV dimension.
  - FlashMLA uses two warpgroups: one for QK and PV1, another for data loading and PV2.
  - FlashInfer uses three warpgroups: one for data loading, one for QK1 and PV1, and one for QK2 and PV2.
  - We should try FlashMLA-style warp specialization and check which one is better.
  - Another possible warp-specialization design is to introduce another warpgroup for QK: one for data loading, one for QK, one for PV1, and one for PV2.
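The PV head-dimension split used by both kernels can be sanity-checked in plain Python (a sketch of the math, not kernel code): because softmax normalization runs along the KV axis, slicing V's columns between two warpgroups (PV1/PV2) is exact and needs no merge step, unlike the QK split on the KV dimension, which requires a log-sum-exp merge:

```python
def matmul(p, v):
    """p: m x n attention weights, v: n x d values."""
    n, d = len(v), len(v[0])
    return [[sum(row[j] * v[j][k] for j in range(n)) for k in range(d)]
            for row in p]

def pv_split_head_dim(p, v, half):
    """PV1 computes output columns [0, half); PV2 computes [half, d).

    The column slices are independent, so concatenating them recovers
    the full PV product exactly.
    """
    o1 = matmul(p, [row[:half] for row in v])  # PV1's slice of V
    o2 = matmul(p, [row[half:] for row in v])  # PV2's slice of V
    return [a + b for a, b in zip(o1, o2)]     # concatenate column slices
```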
- Misc items
  - Register allocation optimizations (perf: tweak register amount for producer/consumer in MLA template #896)
  - Defer the synchronization of `p_smem` and change the unroll number (perf: tweak the pipeline design of mla kernel #901)