perf: dynamic split-k for MLA #863

Merged
merged 5 commits into from
Feb 17, 2025
@yzh119 yzh119 commented Feb 17, 2025

#804 didn't implement split-k, which can degrade performance when concurrency is not large enough. This PR fixes the issue.
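As a rough illustration of what split-k means for attention, the sketch below splits the KV sequence of a single query into chunks, computes a partial output and log-sum-exp per chunk, and merges them. It is pure Python for readability and is not the kernel code from this PR; the function names `partial_attention`, `merge_states`, and `split_k_attention` are made up for this example.

```python
import math

def partial_attention(scores, values):
    """Attention over one KV chunk for a single query.

    Returns (output, lse): the output is normalised within the chunk and
    lse is the log-sum-exp of the chunk's scores.
    """
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom
           for d in range(dim)]
    return out, m + math.log(denom)

def merge_states(states):
    """Combine per-chunk (output, lse) pairs into the full-sequence result."""
    lse_max = max(lse for _, lse in states)
    scales = [math.exp(lse - lse_max) for _, lse in states]
    total = sum(scales)
    dim = len(states[0][0])
    out = [sum(s * o[d] for s, (o, _) in zip(scales, states)) / total
           for d in range(dim)]
    return out, lse_max + math.log(total)

def split_k_attention(scores, values, chunk):
    """Process the KV sequence in independent chunks, then merge."""
    states = [partial_attention(scores[i:i + chunk], values[i:i + chunk])
              for i in range(0, len(scores), chunk)]
    return merge_states(states)

scores = [0.1 * i for i in range(10)]
values = [[float(i), 1.0] for i in range(10)]
full_out, full_lse = partial_attention(scores, values)
split_out, split_lse = split_k_attention(scores, values, chunk=3)
print(full_out, split_out)
```

Because each chunk is independent, the chunks of one long request can be spread over several thread blocks when the batch alone cannot fill the GPU; the merge step is the extra cost.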

We implemented the v2 scheduler and the write-through optimization described in our paper (Section 3.3 and Appendix D.2) for load balancing.

In an early PR (#72), we turned off `cudaLaunchCooperativeKernel` and `grid.sync()` because we were not sure whether they were compatible with CUDAGraph. This PR adds them back for grid synchronization, to save some kernel launch overhead.

Benchmark

On H100 SXM5 80GB (3352 GB/s), this PR:

Config: batch_size=1, seq_len=1024, num_heads=16
Memory bandwidth: 22.33 GB/s
Config: batch_size=16, seq_len=1024, num_heads=16
Memory bandwidth: 330.72 GB/s
Config: batch_size=32, seq_len=1024, num_heads=16
Memory bandwidth: 638.73 GB/s
Config: batch_size=64, seq_len=1024, num_heads=16
Memory bandwidth: 1188.90 GB/s
Config: batch_size=1, seq_len=2048, num_heads=16
Memory bandwidth: 40.74 GB/s
Config: batch_size=16, seq_len=2048, num_heads=16
Memory bandwidth: 592.77 GB/s
Config: batch_size=32, seq_len=2048, num_heads=16
Memory bandwidth: 1112.83 GB/s
Config: batch_size=64, seq_len=2048, num_heads=16
Memory bandwidth: 1506.01 GB/s
Config: batch_size=1, seq_len=4096, num_heads=16
Memory bandwidth: 72.53 GB/s
Config: batch_size=16, seq_len=4096, num_heads=16
Memory bandwidth: 1007.80 GB/s
Config: batch_size=32, seq_len=4096, num_heads=16
Memory bandwidth: 1438.99 GB/s
Config: batch_size=64, seq_len=4096, num_heads=16
Memory bandwidth: 1730.62 GB/s
Config: batch_size=1, seq_len=8192, num_heads=16
Memory bandwidth: 120.74 GB/s
Config: batch_size=16, seq_len=8192, num_heads=16
Memory bandwidth: 1340.86 GB/s
Config: batch_size=32, seq_len=8192, num_heads=16
Memory bandwidth: 1689.36 GB/s
Config: batch_size=64, seq_len=8192, num_heads=16
Memory bandwidth: 1901.26 GB/s
Config: batch_size=1, seq_len=16384, num_heads=16
Memory bandwidth: 177.94 GB/s
Config: batch_size=16, seq_len=16384, num_heads=16
Memory bandwidth: 1619.51 GB/s
Config: batch_size=32, seq_len=16384, num_heads=16
Memory bandwidth: 1876.50 GB/s
Config: batch_size=64, seq_len=16384, num_heads=16
Memory bandwidth: 2010.58 GB/s
Config: batch_size=1, seq_len=32768, num_heads=16
Memory bandwidth: 231.70 GB/s
Config: batch_size=16, seq_len=32768, num_heads=16
Memory bandwidth: 1835.16 GB/s
Config: batch_size=32, seq_len=32768, num_heads=16
Memory bandwidth: 1997.24 GB/s
Config: batch_size=64, seq_len=32768, num_heads=16
Memory bandwidth: 2067.99 GB/s

Before this PR:

Config: batch_size=1, seq_len=1024, num_heads=16
Memory bandwidth: 15.46 GB/s
Config: batch_size=16, seq_len=1024, num_heads=16
Memory bandwidth: 238.49 GB/s
Config: batch_size=32, seq_len=1024, num_heads=16
Memory bandwidth: 472.44 GB/s
Config: batch_size=64, seq_len=1024, num_heads=16
Memory bandwidth: 929.12 GB/s
Config: batch_size=1, seq_len=2048, num_heads=16
Memory bandwidth: 15.47 GB/s
Config: batch_size=16, seq_len=2048, num_heads=16
Memory bandwidth: 250.71 GB/s
Config: batch_size=32, seq_len=2048, num_heads=16
Memory bandwidth: 500.21 GB/s
Config: batch_size=64, seq_len=2048, num_heads=16
Memory bandwidth: 996.37 GB/s
Config: batch_size=1, seq_len=4096, num_heads=16
Memory bandwidth: 16.36 GB/s
Config: batch_size=16, seq_len=4096, num_heads=16
Memory bandwidth: 257.59 GB/s
Config: batch_size=32, seq_len=4096, num_heads=16
Memory bandwidth: 515.88 GB/s
Config: batch_size=64, seq_len=4096, num_heads=16
Memory bandwidth: 1035.55 GB/s
Config: batch_size=1, seq_len=8192, num_heads=16
Memory bandwidth: 16.37 GB/s
Config: batch_size=16, seq_len=8192, num_heads=16
Memory bandwidth: 261.47 GB/s
Config: batch_size=32, seq_len=8192, num_heads=16
Memory bandwidth: 524.76 GB/s
Config: batch_size=64, seq_len=8192, num_heads=16
Memory bandwidth: 1054.54 GB/s
Config: batch_size=1, seq_len=16384, num_heads=16
Memory bandwidth: 16.50 GB/s
Config: batch_size=16, seq_len=16384, num_heads=16
Memory bandwidth: 263.69 GB/s
Config: batch_size=32, seq_len=16384, num_heads=16
Memory bandwidth: 528.89 GB/s
Config: batch_size=64, seq_len=16384, num_heads=16
Memory bandwidth: 1064.87 GB/s
Config: batch_size=1, seq_len=32768, num_heads=16
Memory bandwidth: 16.45 GB/s
Config: batch_size=16, seq_len=32768, num_heads=16
Memory bandwidth: 264.66 GB/s
Config: batch_size=32, seq_len=32768, num_heads=16
Memory bandwidth: 530.87 GB/s
Config: batch_size=64, seq_len=32768, num_heads=16
Memory bandwidth: 1070.93 GB/s
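
To make the two listings easier to compare, here is a small sketch that computes the speedup and the fraction of 3352 GB/s reached for a few configurations. The numbers are copied from the listings above; treating 3352 GB/s as the peak memory bandwidth is an assumption based on the header line, and `summarize` is a helper written for this example.

```python
PEAK_GBPS = 3352.0  # assumed to be the peak bandwidth quoted in the header

# (batch_size, seq_len) -> GB/s, a subset copied from the listings above
after = {(1, 1024): 22.33, (64, 1024): 1188.90,
         (1, 32768): 231.70, (64, 32768): 2067.99}
before = {(1, 1024): 15.46, (64, 1024): 929.12,
          (1, 32768): 16.45, (64, 32768): 1070.93}

def summarize(before, after, peak):
    """Return {config: (speedup, percent of peak after the PR)}."""
    rows = {}
    for cfg in sorted(after):
        speedup = after[cfg] / before[cfg]
        utilization = 100.0 * after[cfg] / peak
        rows[cfg] = (speedup, utilization)
    return rows

summary = summarize(before, after, PEAK_GBPS)
for (bs, sl), (speedup, util) in summary.items():
    print(f"batch_size={bs:<3} seq_len={sl:<6} "
          f"speedup={speedup:6.2f}x  of peak={util:5.1f}%")
```

On these numbers the largest gain is at batch_size=1, seq_len=32768 (231.70 vs 16.45 GB/s, about 14x), which is the low-concurrency, long-sequence case that split-k targets. At batch_size=64 with the same sequence length the gain is about 1.9x.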

@yzh119 yzh119 merged commit 41a4f56 into main Feb 17, 2025
@zhyncs zhyncs deleted the mla-fa2-split-k branch February 17, 2025 17:21
@yzh119
Collaborator Author

yzh119 commented Feb 17, 2025

Note that this PR uses f32 as the intermediate data type for the output matrix in split-k.

Using a 16-bit data type (e.g. f16/bf16) might be faster, but would result in some accuracy loss, especially for bf16 with long context.
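
For a sense of scale, the sketch below emulates bf16 rounding in pure Python and shows how much a partial value changes when it is stored in bf16 instead of f32. It only illustrates the precision of the format (8 significant bits); it does not reproduce the kernel's numerics, and `to_bf16` is a helper written for this example that ignores NaN handling.

```python
import struct

def to_bf16(x):
    """Round x to the nearest bfloat16 value (ties to even), returned as a float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a bias so that dropping the low 16 bits rounds to nearest-even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

partials = [0.7353, 0.1234, 0.9871]  # illustrative values, not measured data
for p in partials:
    q = to_bf16(p)
    print(f"f32 {p:.6f} -> bf16 {q:.6f}  relative error {abs(q - p) / p:.2e}")

# A small correction to a partial result can vanish entirely in bf16.
print(to_bf16(1.0 + 2 ** -9) == 1.0)
```

Each stored partial can be off by up to 2^-8 in relative terms, which is why the intermediate buffer is kept in f32 here.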

@zhyncs
Member

zhyncs commented Feb 17, 2025

This PR is amazing!!

yzh119 added a commit that referenced this pull request Feb 17, 2025
The scheduling algorithm in #863 does not account for requests whose
kv-cache length is 0; this PR fixes the issue.
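
A minimal sketch of why zero-length requests need care in a split-k planner: the loop over KV chunks produces no work at all for such a request unless it is handled explicitly. This is not FlashInfer's scheduler. `plan_split_k`, `min_chunk`, and the choice to emit one empty work item per zero-length request are all assumptions made for illustration.

```python
import math

def plan_split_k(kv_lens, num_ctas, min_chunk=256):
    """Return (chunk_size, work_items).

    Each work item is (request_idx, kv_start, kv_end). The chunk size targets
    roughly one chunk per CTA but never drops below min_chunk.
    """
    total = sum(kv_lens)
    chunk = max(min_chunk, math.ceil(total / num_ctas))
    work = []
    for req, n in enumerate(kv_lens):
        if n == 0:
            # Assumed handling: keep one empty item so the request still
            # gets an output slot instead of being silently skipped.
            work.append((req, 0, 0))
            continue
        for start in range(0, n, chunk):
            work.append((req, start, min(start + chunk, n)))
    return chunk, work

chunk, work = plan_split_k([1024, 0, 300], num_ctas=4)
print(chunk)
for item in work:
    print(item)
```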
yzh119 added a commit that referenced this pull request Feb 23, 2025
This PR is a follow-up to #804: we implemented a FlashAttention-3
version of the warp specialization pattern (splitting on head-dimension) in
#804 for faster attention on Hopper GPUs. Compared to the previous
version (in FA2 style), this PR makes the following changes:

1. use one warpgroup for the producer and two warpgroups for the consumers.
2. use async wgmma instead of mma.
3. use the software pipeline algorithm from FlashAttention-3 to overlap
CUDA-core and Tensor-Core operations.
4. Compared to original attention, MLA uses the same matrix for K and V (the
ckv matrix). If we reused `CTA_TILE_KV=64` and `PIPE_STAGES=2`, the
software pipeline algorithm would block the memory copy for the next KV-tile
(both pipe slots would be occupied). Original attention does not have
this issue because it has both `pipeline_k` and `pipeline_v`, doubling
the stages. This PR changes to `CTA_TILE_KV=32` and `PIPE_STAGES=4` to
ensure we can compute the current KV-tile while loading the next
KV-tile when using the software pipeline.
5. Unlike original attention, we can't reuse the V shared memory space for
O. This PR designs a circular buffer for `o_smem` that reuses the KV
slots. One KV slot is not large enough for `o_smem`, so we use two KV
shared memory slots for one `o_smem`; a barrier is required to guarantee
the memory order.
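
The shared-memory arithmetic behind items 4 and 5 can be sketched as follows. The shapes are assumptions that do not appear in the text above: a 512-wide ckv part plus a 64-wide rope part per KV row, 16-bit elements for both the KV tiles and `o_smem`, and a 64-row query tile. Under those assumptions, halving the tile and doubling the stages keeps the footprint unchanged, and one `o_smem` needs two KV slots.

```python
# Assumed shapes, for illustration only (not taken from the PR text).
HEAD_DIM_CKV = 512   # assumed width of the shared K/V (ckv) part
HEAD_DIM_KPE = 64    # assumed width of the rope part of K
BYTES_PER_ELEM = 2   # assumed 16-bit storage
CTA_TILE_Q = 64      # assumed number of query rows per CTA

def kv_slot_bytes(cta_tile_kv):
    """Bytes of shared memory for one pipeline slot holding one KV tile."""
    return cta_tile_kv * (HEAD_DIM_CKV + HEAD_DIM_KPE) * BYTES_PER_ELEM

def pipeline_bytes(cta_tile_kv, pipe_stages):
    """Total bytes for all pipeline slots."""
    return kv_slot_bytes(cta_tile_kv) * pipe_stages

def slots_needed_for_o(cta_tile_kv):
    """Number of KV slots required to hold one output tile (ceiling division)."""
    o_smem_bytes = CTA_TILE_Q * HEAD_DIM_CKV * BYTES_PER_ELEM
    return -(-o_smem_bytes // kv_slot_bytes(cta_tile_kv))

print(pipeline_bytes(64, 2), pipeline_bytes(32, 4))  # same footprint
print(kv_slot_bytes(32), slots_needed_for_o(32))     # one slot is too small for O
```

With four smaller slots instead of two larger ones, a slot can be refilled while another is being consumed, which is the overlap that item 4 describes.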

## Pipeline

This figure explains our pipeline design:

![pipeline-design-mla](https://github.com/user-attachments/assets/178e465e-e671-459f-a4ea-02e2eaf17343)

## Results

Benchmark results on H100 SXM5 (80GB).

This PR (fa3 template), `page_size=1`:
```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1305.40 GB/s
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 2228.56 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2759.33 GB/s
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 1766.33 GB/s
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 2498.08 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2768.37 GB/s
```

#804 + #863 (fa2 template), `page_size=1`:
```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1067.74 GB/s
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 1761.25 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2065.78 GB/s
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 1384.35 GB/s
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 1892.64 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2075.97 GB/s
```

Using TMA and multicast could further improve performance for
`page_size` larger than 1; we leave them for future work.