FlashMLA from DeepSeek #892
Comments
I came here for it! @zhyncs was really fast
#887 — how about this? Compare vs https://github.com/deepseek-ai/FlashMLA ?
The pipeline design is a little bit different from #887; I'll check what we can learn from it.
@zhyncs @celsowm @MichoChan here are the results I got on H100 by running the latest flashinfer and FlashMLA mainline (higher is better); for flashinfer we use page_size=1 and FlashMLA uses page_size=64.
Here's my benchmark code and results on H100: https://gist.github.com/abcdabcd987/b215c5f00f4b5e8399b95d7933bcf475 https://docs.google.com/spreadsheets/d/1t0Txa7Ph9u7Su9LyWpS24vqr9A5FB-FyL0EZNpYOqwg/edit?gid=0#gid=0 Both are using page size 64. FlashMLA is faster in general, and way faster on small batch sizes.
As pointed out in #892 (comment), the second stage of split-k seems to have a huge overhead. This PR is the first step in addressing these issues, by changing the vector size from 4 to 8.
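For readers unfamiliar with what the vector-size change means in practice, here is a minimal, hypothetical CUDA sketch (not FlashInfer's actual kernel; the function name, layouts, and shapes are made up for illustration) of a split-k combine loop using vector size 8: each thread issues one 16-byte load (8 fp16 values) per split instead of an 8-byte load (4 values), halving the number of load instructions for the same data.

```cuda
// Hypothetical sketch of a split-k combine with vector size 8 (not FlashInfer's code).
// Each thread accumulates 8 contiguous fp16 output elements across all splits.
// The real combine also rescales each split by its softmax statistics; that is
// omitted here to keep the example focused on the memory-access width.
#include <cuda_fp16.h>

__global__ void combine_splits_vec8(const __half* __restrict__ partial,  // [num_splits, num_rows, d]
                                    __half* __restrict__ out,            // [num_rows, d]
                                    int num_splits, int num_rows, int d) {
  int row = blockIdx.x;          // one CTA per output row
  int col = threadIdx.x * 8;     // 8 fp16 elements per thread
  if (row >= num_rows || col >= d) return;

  float acc[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
  for (int s = 0; s < num_splits; ++s) {
    // One 16-byte load = 8 halves (vector size 8); with vector size 4 this
    // would be an 8-byte (uint2) load and twice as many load instructions.
    uint4 v = *reinterpret_cast<const uint4*>(
        partial + ((size_t)s * num_rows + row) * d + col);
    const __half* h = reinterpret_cast<const __half*>(&v);
#pragma unroll
    for (int i = 0; i < 8; ++i) acc[i] += __half2float(h[i]);
  }
#pragma unroll
  for (int i = 0; i < 8; ++i) out[(size_t)row * d + col + i] = __float2half(acc[i]);
}

// Example launch (assuming d is a multiple of 8 and pointers are 16-byte aligned):
//   combine_splits_vec8<<<num_rows, d / 8>>>(partial, out, num_splits, num_rows, d);
```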
Hi @abcdabcd987, yes, I didn't profile the low-batch-size use cases, and I just realized we get low performance for small batch and long context. #894 alleviates the issue a little bit. Regarding the cases where qo_len * num_heads >= 128, the current flashinfer implementation is not good at this, because we prioritize
I found DeepSeek FlashMLA is much faster than flashinfer when q_head_num equals 128 (tp1): almost 100% faster when bs=32. But when q_head_num is 16, 32, or 64, it is only 10%-20% faster.
We will try out the FlashMLA-style warp specialization in the next release. Created an issue for performance tracking: #897
As observed in #892, we found flashinfer MLA's second stage of split-k is very slow (when batch size is small); this is because our scheduler only uses one CTA for the second stage of split-k. This PR fixes the issue.
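As a rough illustration of why one CTA is a bottleneck: the second stage of split-k merges the per-split partial outputs using their log-sum-exp (LSE) statistics, and there is one such merge per (request, head) output row, so the work parallelizes naturally across CTAs. Below is a hedged CUDA sketch (hypothetical names and layouts, not the actual FlashInfer scheduler or kernel) of a merge stage that assigns one CTA per output row instead of running everything in a single CTA.

```cuda
// Hypothetical sketch (not FlashInfer's code): merge split-k partial attention
// outputs, one CTA per output row. Each split s contributes a partial output
// o_s and its log-sum-exp lse_s; the merged output is sum_s softmax(lse)_s * o_s.
#include <math_constants.h>

__global__ void merge_splits(const float* __restrict__ partial_o,   // [num_splits, num_rows, d]
                             const float* __restrict__ partial_lse, // [num_splits, num_rows]
                             float* __restrict__ out,               // [num_rows, d]
                             int num_splits, int num_rows, int d) {
  int row = blockIdx.x;  // one CTA per (request, head) output row
  if (row >= num_rows) return;

  // First pass (redundantly per thread): running max and normalizer over splits.
  float m = -CUDART_INF_F, denom = 0.f;
  for (int s = 0; s < num_splits; ++s) {
    float lse = partial_lse[(size_t)s * num_rows + row];
    float new_m = fmaxf(m, lse);
    denom = denom * expf(m - new_m) + expf(lse - new_m);
    m = new_m;
  }

  // Second pass: weighted sum of partial outputs, parallel over the head dimension.
  for (int col = threadIdx.x; col < d; col += blockDim.x) {
    float acc = 0.f;
    for (int s = 0; s < num_splits; ++s) {
      float w = expf(partial_lse[(size_t)s * num_rows + row] - m) / denom;
      acc += w * partial_o[((size_t)s * num_rows + row) * d + col];
    }
    out[(size_t)row * d + col] = acc;
  }
}

// Example launch: merge_splits<<<num_rows, 128>>>(po, plse, out, num_splits, num_rows, d);
```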
Hello, I noticed the significant speed improvement in the latest test results, but the test script throws errors when running with the new version of FlashInfer. Do modifications need to be made to the test script?
@yanghailong-git can you report the error message?
When running this script https://gist.github.com/abcdabcd987/b215c5f00f4b5e8399b95d7933bcf475 with version v0.2.2.post1, I encountered the error below. How should I resolve this? Thanks.
Can you post the full error message as text instead? Some key information was clipped in your screenshot.
The detailed error is as follows:
@yanghailong-git #904 should fix it.
This PR implements #892 . Per the benchmark, the 2WG pipeline (FlashMLA's implementation) is faster than our current 3WG pipeline design on Hopper. While it remains under investigation where the gap comes from, we should implement the 2WG (and, in the future, 4WG) pipeline in FlashInfer to make sure our implementation does not get worse performance than FlashMLA.

## Performance

Before this PR:
```
Config: batch_size=64,  seq_len=1024, num_heads=64   Memory bandwidth: 1547.23 GB/s  FLOPs: 167.29 TFLOPs
Config: batch_size=64,  seq_len=1024, num_heads=128  Memory bandwidth: 1483.82 GB/s  FLOPs: 290.23 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=64   Memory bandwidth: 2238.72 GB/s  FLOPs: 242.06 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=128  Memory bandwidth: 1612.66 GB/s  FLOPs: 315.43 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=64   Memory bandwidth: 2821.32 GB/s  FLOPs: 305.05 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=128  Memory bandwidth: 1767.63 GB/s  FLOPs: 345.74 TFLOPs
Config: batch_size=64,  seq_len=2048, num_heads=64   Memory bandwidth: 1960.50 GB/s  FLOPs: 223.79 TFLOPs
Config: batch_size=64,  seq_len=2048, num_heads=128  Memory bandwidth: 1533.88 GB/s  FLOPs: 331.70 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=64   Memory bandwidth: 2546.83 GB/s  FLOPs: 290.72 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=128  Memory bandwidth: 1629.73 GB/s  FLOPs: 352.43 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=64   Memory bandwidth: 2820.22 GB/s  FLOPs: 321.93 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=128  Memory bandwidth: 1657.89 GB/s  FLOPs: 358.52 TFLOPs
Config: batch_size=64,  seq_len=8192, num_heads=64   Memory bandwidth: 2682.98 GB/s  FLOPs: 319.63 TFLOPs
Config: batch_size=64,  seq_len=8192, num_heads=128  Memory bandwidth: 1600.79 GB/s  FLOPs: 375.94 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=64   Memory bandwidth: 2803.48 GB/s  FLOPs: 333.98 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=128  Memory bandwidth: 1584.79 GB/s  FLOPs: 372.18 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=64   Memory bandwidth: 2768.36 GB/s  FLOPs: 329.80 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=128  Memory bandwidth: 1565.82 GB/s  FLOPs: 367.73 TFLOPs
```

After this PR:
```
Config: batch_size=64,  seq_len=1024, num_heads=64   Memory bandwidth: 1509.87 GB/s  FLOPs: 163.25 TFLOPs
Config: batch_size=64,  seq_len=1024, num_heads=128  Memory bandwidth: 1766.19 GB/s  FLOPs: 345.46 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=64   Memory bandwidth: 2307.97 GB/s  FLOPs: 249.55 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=128  Memory bandwidth: 1975.24 GB/s  FLOPs: 386.35 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=64   Memory bandwidth: 2871.63 GB/s  FLOPs: 310.49 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=128  Memory bandwidth: 2225.07 GB/s  FLOPs: 435.21 TFLOPs
Config: batch_size=64,  seq_len=2048, num_heads=64   Memory bandwidth: 1948.15 GB/s  FLOPs: 222.38 TFLOPs
Config: batch_size=64,  seq_len=2048, num_heads=128  Memory bandwidth: 1973.36 GB/s  FLOPs: 426.74 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=64   Memory bandwidth: 2625.63 GB/s  FLOPs: 299.72 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=128  Memory bandwidth: 2121.92 GB/s  FLOPs: 458.86 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=64   Memory bandwidth: 2996.11 GB/s  FLOPs: 342.01 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=128  Memory bandwidth: 2146.40 GB/s  FLOPs: 464.16 TFLOPs
Config: batch_size=64,  seq_len=8192, num_heads=64   Memory bandwidth: 2717.28 GB/s  FLOPs: 323.71 TFLOPs
Config: batch_size=64,  seq_len=8192, num_heads=128  Memory bandwidth: 2129.24 GB/s  FLOPs: 500.04 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=64   Memory bandwidth: 3002.75 GB/s  FLOPs: 357.72 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=128  Memory bandwidth: 2101.93 GB/s  FLOPs: 493.63 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=64   Memory bandwidth: 3083.42 GB/s  FLOPs: 367.33 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=128  Memory bandwidth: 2064.96 GB/s  FLOPs: 484.95 TFLOPs
```

## Note

1. The profiler is broken (we changed the pipeline structure; it will be added back in later PRs).
2. There is still room for improvement in the pipeline design, e.g. we can prefetch the next tile's first kv-cache load, which could further improve performance; we leave this for future work.
   ![pipeline diagram](https://github.com/user-attachments/assets/e84b1d55-3361-48a1-b339-97837cb97bfb)
3. Synchronization is still sub-optimal: we insert a `__syncthreads()` in each iteration to guarantee correctness. Can we further improve this?
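To make note 3 concrete, here is a heavily simplified, hypothetical CUDA sketch of a 2-warpgroup producer/consumer pipeline with a double-buffered shared-memory tile and a single `__syncthreads()` per iteration. It is not the actual FlashInfer/FlashMLA kernel (the real kernel uses Hopper warpgroup MMAs, TMA, and mbarriers, none of which appear here); all names and the toy reduction are invented for illustration.

```cuda
// Hypothetical 2-warpgroup (2WG) producer/consumer sketch (not the real kernel).
// Warpgroup 0 loads the next KV tile into shared memory while warpgroup 1
// computes on the current one; one __syncthreads() per iteration keeps the
// double buffer consistent.
constexpr int kTile = 1024;     // elements per KV tile (toy size)
constexpr int kThreads = 256;   // 2 warpgroups x 128 threads

__global__ void two_wg_pipeline(const float* __restrict__ kv, float* __restrict__ out,
                                int num_tiles) {
  __shared__ float smem[2][kTile];      // double buffer
  int wg = threadIdx.x / 128;           // 0 = producer, 1 = consumer
  int lane = threadIdx.x % 128;
  float acc = 0.f;

  // Prologue: producer fills buffer 0 with tile 0.
  if (wg == 0) {
    for (int i = lane; i < kTile; i += 128) smem[0][i] = kv[i];
  }
  __syncthreads();

  for (int t = 0; t < num_tiles; ++t) {
    int cur = t & 1, nxt = (t + 1) & 1;
    if (wg == 0 && t + 1 < num_tiles) {
      // Producer: prefetch the next tile into the other buffer.
      for (int i = lane; i < kTile; i += 128)
        smem[nxt][i] = kv[(size_t)(t + 1) * kTile + i];
    } else if (wg == 1) {
      // Consumer: "compute" on the current tile (a toy reduction here).
      for (int i = lane; i < kTile; i += 128) acc += smem[cur][i];
    }
    // One barrier per iteration: the consumer finishes with `cur` before the
    // producer overwrites it two iterations later, and the producer's write to
    // `nxt` becomes visible before the consumer reads it next iteration.
    __syncthreads();
  }

  if (wg == 1) out[lane] = acc;  // partial sums; a real kernel would reduce further
}
```

With only the per-iteration barrier, the producer's load of tile t+1 overlaps the consumer's compute on tile t, which is the overlap the 2WG design buys; finer-grained synchronization (e.g. mbarriers) would let the producer run further ahead, which is what note 3 hints at.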
Done in #952 . There are some slight discrepancies between the current flashinfer MLA implementation (we support any page size and any query length) and FlashMLA's, but the pipeline structure is the same now.
as titled
ref https://github.com/deepseek-ai/FlashMLA