perf: dynamic split-k for MLA #863
Merged
Note that this PR uses f32 as the intermediate data type for the output matrix in split-k. Using a 16-bit data type (e.g. f16/bf16) might be faster, but would cost some accuracy, especially for bf16 with long contexts.
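To make the accuracy trade-off concrete, here is a minimal pure-Python sketch (not code from this PR) that rounds a set of made-up partial outputs to f32, f16, and bf16 and reports the worst rounding error introduced by storing them. The helper names (`round_to_bf16`, `max_storage_error`) and the sample values are assumptions for illustration; bf16 is emulated by keeping the top 16 bits of the f32 encoding.

```python
import struct

def round_to_f32(x):
    """Round a Python float (f64) to the nearest f32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def round_to_f16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def round_to_bf16(x):
    """Emulate bf16: keep the top 16 bits of the f32 encoding, round-to-nearest-even.

    Only intended for small finite values such as the ones used below.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def max_storage_error(partials, store):
    """Worst absolute error introduced by storing each partial output with `store`."""
    return max(abs(store(p) - p) for p in partials)

# Hypothetical partial outputs, one per KV split (values are made up).
partials = [0.1 + 0.001 * i for i in range(64)]

err_f32 = max_storage_error(partials, round_to_f32)
err_f16 = max_storage_error(partials, round_to_f16)
err_bf16 = max_storage_error(partials, round_to_bf16)
print(f"f32: {err_f32:.2e}  f16: {err_f16:.2e}  bf16: {err_bf16:.2e}")
```

In this toy setting bf16 loses the most precision per stored value because it keeps only 8 significant bits, compared with 11 for f16 and 24 for f32.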
This PR is amazing!!
yzh119 added a commit that referenced this pull request on Feb 17, 2025
The scheduling algorithm in #863 does not account for requests whose kv-cache length is 0; this PR fixes the issue.
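As an illustration of the kind of edge case involved, here is a minimal sketch of a chunk scheduler that tolerates zero-length requests. It is not the scheduler from #863: the greedy least-loaded assignment, the function name `schedule`, and the `chunk_size` parameter are all assumptions made for this example.

```python
import heapq

def schedule(kv_lens, num_workers, chunk_size):
    """Assign KV chunks to workers, always giving the next chunk to the least-loaded worker.

    Returns (plan, empty_requests), where plan[w] is a list of
    (request_id, start, end) tuples and empty_requests lists the requests
    whose kv-cache length is 0 (they get no chunk at all).
    """
    chunks = []
    empty_requests = []
    for req, kv_len in enumerate(kv_lens):
        if kv_len == 0:
            # Record the request explicitly instead of emitting a zero-length chunk.
            empty_requests.append(req)
            continue
        for start in range(0, kv_len, chunk_size):
            chunks.append((req, start, min(start + chunk_size, kv_len)))

    # Place the longest chunks first so short tail chunks can fill the gaps.
    chunks.sort(key=lambda c: c[2] - c[1], reverse=True)

    heap = [(0, w) for w in range(num_workers)]  # (current load, worker id)
    plan = [[] for _ in range(num_workers)]
    for req, start, end in chunks:
        load, w = heapq.heappop(heap)
        plan[w].append((req, start, end))
        heapq.heappush(heap, (load + (end - start), w))
    return plan, empty_requests

plan, empty = schedule([0, 100, 0, 37], num_workers=4, chunk_size=32)
print(plan, empty)
```

Tracking the empty requests separately lets the caller decide what their output should be, rather than relying on the kernel to handle an empty work item.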
yzh119 added a commit that referenced this pull request on Feb 23, 2025
This PR is a follow-up to #804, where we implemented a FlashAttention-3 version of the warp specialization pattern (splitting on the head dimension) for faster attention on Hopper GPUs. Compared to the previous version (in FA2 style), this PR makes the following changes:

1. Use one warpgroup for the producer and two warpgroups for the consumers.
2. Use async `wgmma` instead of `mma`.
3. Use the software pipeline algorithm from FlashAttention-3 to overlap CUDA Core and Tensor Core operations.
4. Unlike original attention, MLA uses the same matrix for K and V (the ckv matrix). If we kept `CTA_TILE_KV=64` and `PIPE_STAGES=2`, the software pipeline algorithm would block the memory copy for the next KV tile, because both pipe slots would be occupied. Original attention does not have this issue because it has both `pipeline_k` and `pipeline_v`, doubling the stages. This PR switches to `CTA_TILE_KV=32` and `PIPE_STAGES=4` so that we can compute the current KV tile while loading the next one when using the software pipeline.
5. Unlike original attention, we can't reuse the V shared memory space for O. This PR designs a circular buffer for `o_smem` that reuses the KV slots. One KV slot is not large enough for `o_smem`, so we use two KV shared memory slots for one `o_smem`; a barrier is required to guarantee memory ordering.

## Pipeline

This figure explains our pipeline design: [figure: pipeline design]

## Results

Benchmark results on H100 SXM3 (80GB).
This PR (fa3 template), `page_size=1`:

```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1305.40 GB/s
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 2228.56 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2759.33 GB/s
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 1766.33 GB/s
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 2498.08 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2768.37 GB/s
```

#804 + #863 (fa2 template), `page_size=1`:

```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1067.74 GB/s
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 1761.25 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2065.78 GB/s
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 1384.35 GB/s
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 1892.64 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2075.97 GB/s
```

Using TMA and multicast could further improve performance for `page_size` larger than 1; we leave them for future work.
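For a quick side-by-side reading of the two result sets, the following small script computes the per-configuration speedup of the fa3 template over the fa2 template. The bandwidth figures are copied from the results above; the dictionary layout and variable names are just one convenient way to organize them.

```python
# Memory bandwidth in GB/s, keyed by (batch_size, seq_len); num_heads=64, page_size=1.
fa3 = {
    (64, 1024): 1305.40, (128, 1024): 2228.56, (768, 1024): 2759.33,
    (64, 2048): 1766.33, (128, 2048): 2498.08, (768, 2048): 2768.37,
}
fa2 = {
    (64, 1024): 1067.74, (128, 1024): 1761.25, (768, 1024): 2065.78,
    (64, 2048): 1384.35, (128, 2048): 1892.64, (768, 2048): 2075.97,
}

# Ratio of achieved bandwidth, fa3 over fa2, for each configuration.
speedups = {cfg: fa3[cfg] / fa2[cfg] for cfg in fa3}

for (batch_size, seq_len), ratio in sorted(speedups.items(), key=lambda kv: (kv[0][1], kv[0][0])):
    print(f"batch_size={batch_size:<4d} seq_len={seq_len}: {ratio:.2f}x")
```

On these numbers the fa3 template is roughly 1.22x to 1.34x faster than the fa2 template, with the smallest gain at `batch_size=64, seq_len=1024`.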
#804 didn't implement split-k, which might result in performance degradation if concurrency is not large enough. This PR fixes the issue.
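For readers unfamiliar with split-k in attention, here is a minimal pure-Python sketch of the idea (not the CUDA kernel in this PR): the KV sequence of one query is cut into chunks that could be processed by different CTAs, each chunk yields a partial output plus the log-sum-exp of its scores, and the partials are merged afterwards. The function names and toy sizes are assumptions for illustration.

```python
import math

def attend_chunk(q, keys, values):
    """Attention of one query over one KV chunk.

    Returns (output normalized within the chunk, log-sum-exp of the chunk's scores).
    """
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]
    return out, m + math.log(denom)

def merge_states(states):
    """Merge per-chunk (output, lse) pairs, weighting each chunk by exp(lse)."""
    lse_max = max(lse for _, lse in states)
    scales = [math.exp(lse - lse_max) for _, lse in states]
    total = sum(scales)
    dim = len(states[0][0])
    return [sum(s * o[d] for s, (o, _) in zip(scales, states)) / total for d in range(dim)]

def split_k_attention(q, keys, values, num_splits):
    """Split the KV sequence into up to num_splits chunks and merge the partial results."""
    chunk = -(-len(keys) // num_splits)  # ceiling division
    states = [
        attend_chunk(q, keys[i:i + chunk], values[i:i + chunk])
        for i in range(0, len(keys), chunk)
    ]
    return merge_states(states)

# Toy data: one query, 37 KV entries, head dimension 4.
q = [0.3, -0.2, 0.5, 0.1]
keys = [[math.sin(i * 4 + d) for d in range(4)] for i in range(37)]
values = [[math.cos(i * 4 + d) for d in range(4)] for i in range(37)]

reference, _ = attend_chunk(q, keys, values)
merged = split_k_attention(q, keys, values, num_splits=4)
print(max(abs(a - b) for a, b in zip(reference, merged)))
```

Because each chunk is independent until the merge step, a single long request can keep several CTAs busy, which is what helps when there are too few requests to fill the GPU.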
We implemented the v2 scheduler and the write-through optimization described in our paper (Section 3.3 and Appendix D.2) for load balancing.
In an early PR (#72), we turned off `cudaLaunchCooperativeKernel` and `grid.sync()` because we were not sure whether they were compatible with CUDAGraph. This PR adds them back for grid synchronization, to save some kernel launch overhead.

## Benchmark
On H100 SXM5 80GB (3352 GB/s), this PR:
Before this PR: