
Commit 60d37b7

perf: Use 2WG pipeline design for MLA implementation on Hopper (#952)
This PR implements #892. Per our benchmark, the 2WG pipeline (FlashMLA's implementation) is faster than our current 3WG pipeline design on Hopper. While where the gap comes from remains under investigation, we should implement the 2WG (and, in the future, 4WG) pipeline in FlashInfer to make sure our implementation does not lag behind FlashMLA in performance.

## Performance

Before this PR:

```
Config: batch_size=64,  seq_len=1024, num_heads=64   Memory bandwidth: 1547.23 GB/s  FLOPs: 167.29 TFLOPs
Config: batch_size=64,  seq_len=1024, num_heads=128  Memory bandwidth: 1483.82 GB/s  FLOPs: 290.23 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=64   Memory bandwidth: 2238.72 GB/s  FLOPs: 242.06 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=128  Memory bandwidth: 1612.66 GB/s  FLOPs: 315.43 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=64   Memory bandwidth: 2821.32 GB/s  FLOPs: 305.05 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=128  Memory bandwidth: 1767.63 GB/s  FLOPs: 345.74 TFLOPs
Config: batch_size=64,  seq_len=2048, num_heads=64   Memory bandwidth: 1960.50 GB/s  FLOPs: 223.79 TFLOPs
Config: batch_size=64,  seq_len=2048, num_heads=128  Memory bandwidth: 1533.88 GB/s  FLOPs: 331.70 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=64   Memory bandwidth: 2546.83 GB/s  FLOPs: 290.72 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=128  Memory bandwidth: 1629.73 GB/s  FLOPs: 352.43 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=64   Memory bandwidth: 2820.22 GB/s  FLOPs: 321.93 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=128  Memory bandwidth: 1657.89 GB/s  FLOPs: 358.52 TFLOPs
Config: batch_size=64,  seq_len=8192, num_heads=64   Memory bandwidth: 2682.98 GB/s  FLOPs: 319.63 TFLOPs
Config: batch_size=64,  seq_len=8192, num_heads=128  Memory bandwidth: 1600.79 GB/s  FLOPs: 375.94 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=64   Memory bandwidth: 2803.48 GB/s  FLOPs: 333.98 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=128  Memory bandwidth: 1584.79 GB/s  FLOPs: 372.18 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=64   Memory bandwidth: 2768.36 GB/s  FLOPs: 329.80 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=128  Memory bandwidth: 1565.82 GB/s  FLOPs: 367.73 TFLOPs
```

After this PR:

```
Config: batch_size=64,  seq_len=1024, num_heads=64   Memory bandwidth: 1509.87 GB/s  FLOPs: 163.25 TFLOPs
Config: batch_size=64,  seq_len=1024, num_heads=128  Memory bandwidth: 1766.19 GB/s  FLOPs: 345.46 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=64   Memory bandwidth: 2307.97 GB/s  FLOPs: 249.55 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=128  Memory bandwidth: 1975.24 GB/s  FLOPs: 386.35 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=64   Memory bandwidth: 2871.63 GB/s  FLOPs: 310.49 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=128  Memory bandwidth: 2225.07 GB/s  FLOPs: 435.21 TFLOPs
Config: batch_size=64,  seq_len=2048, num_heads=64   Memory bandwidth: 1948.15 GB/s  FLOPs: 222.38 TFLOPs
Config: batch_size=64,  seq_len=2048, num_heads=128  Memory bandwidth: 1973.36 GB/s  FLOPs: 426.74 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=64   Memory bandwidth: 2625.63 GB/s  FLOPs: 299.72 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=128  Memory bandwidth: 2121.92 GB/s  FLOPs: 458.86 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=64   Memory bandwidth: 2996.11 GB/s  FLOPs: 342.01 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=128  Memory bandwidth: 2146.40 GB/s  FLOPs: 464.16 TFLOPs
Config: batch_size=64,  seq_len=8192, num_heads=64   Memory bandwidth: 2717.28 GB/s  FLOPs: 323.71 TFLOPs
Config: batch_size=64,  seq_len=8192, num_heads=128  Memory bandwidth: 2129.24 GB/s  FLOPs: 500.04 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=64   Memory bandwidth: 3002.75 GB/s  FLOPs: 357.72 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=128  Memory bandwidth: 2101.93 GB/s  FLOPs: 493.63 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=64   Memory bandwidth: 3083.42 GB/s  FLOPs: 367.33 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=128  Memory bandwidth: 2064.96 GB/s  FLOPs: 484.95 TFLOPs
```

## Note

1. The profiler is broken (we changed the pipeline structure); we will add it back in later PRs.
2. There is still room for improvement in the pipeline design, e.g. we could prefetch the first kv-cache chunk of the next tile, which should further improve performance; we leave this for future work.
   <img width="1230" alt="image" src="https://github.com/user-attachments/assets/e84b1d55-3361-48a1-b339-97837cb97bfb" />
3. Synchronization is still sub-optimal: we insert a `__syncthreads()` in each iteration to guarantee correctness. Can we further improve this?
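For a quick read on where the gains land, the seq_len=8192 FLOPs figures from the two tables above can be compared directly. A small helper script (numbers copied verbatim from the benchmark output above):

```python
# Per-config FLOPs speedup of the 2WG pipeline over the 3WG pipeline,
# for the seq_len=8192 rows of the tables above. Keys are
# (batch_size, seq_len, num_heads); values are TFLOPs.
before = {
    (64, 8192, 64): 319.63, (64, 8192, 128): 375.94,
    (128, 8192, 64): 333.98, (128, 8192, 128): 372.18,
    (768, 8192, 64): 329.80, (768, 8192, 128): 367.73,
}
after = {
    (64, 8192, 64): 323.71, (64, 8192, 128): 500.04,
    (128, 8192, 64): 357.72, (128, 8192, 128): 493.63,
    (768, 8192, 64): 367.33, (768, 8192, 128): 484.95,
}

speedup = {cfg: after[cfg] / before[cfg] for cfg in before}
for (batch_size, seq_len, num_heads), s in sorted(speedup.items()):
    print(f"batch={batch_size:4d} seq_len={seq_len} heads={num_heads:3d}  "
          f"speedup={s:.2f}x")
```

The num_heads=128 configs improve by roughly 1.32-1.33x, while the num_heads=64 gains are more modest (about 1.01-1.11x), suggesting the 2WG design helps most in the high-head-count regime.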
1 parent e19cb7b

File tree

2 files changed: +195 −249 lines changed


benchmarks/bench_deepseek_mla.py (+3, −2)

```diff
@@ -76,6 +76,7 @@ def bench_deepseek_mla_decode(batch_size, seq_len, num_heads, backend):
 
 
 if __name__ == "__main__":
-    for seq_len in [1024, 2048]:
+    for seq_len in [1024, 2048, 8192]:
         for batch_size in [64, 128, 768]:
-            bench_deepseek_mla_decode(batch_size, seq_len, 64, "auto")
+            for num_heads in [64, 128]:
+                bench_deepseek_mla_decode(batch_size, seq_len, num_heads, "auto")
```
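With this change, the sweep covers 3 seq_lens × 3 batch sizes × 2 head counts = 18 configs, matching the 18 rows in each performance table above. A standalone sketch of the updated loop, with `bench_deepseek_mla_decode` stubbed out since the real benchmark needs a Hopper GPU:

```python
def bench_deepseek_mla_decode(batch_size, seq_len, num_heads, backend):
    # Stub for illustration: the real function in
    # benchmarks/bench_deepseek_mla.py runs the MLA decode kernel and
    # prints memory bandwidth / FLOPs for the given config.
    print(f"batch_size={batch_size}, seq_len={seq_len}, num_heads={num_heads}")

# Same sweep order as the diff above: seq_len outermost, num_heads innermost.
configs = [
    (batch_size, seq_len, num_heads)
    for seq_len in [1024, 2048, 8192]
    for batch_size in [64, 128, 768]
    for num_heads in [64, 128]
]
for batch_size, seq_len, num_heads in configs:
    bench_deepseek_mla_decode(batch_size, seq_len, num_heads, "auto")
```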
