# perf: Use 2WG pipeline design for MLA implementation on Hopper (#952)
This PR implements #892.

Per our benchmarks, the 2WG pipeline (FlashMLA's design) is faster than our current 3WG pipeline on Hopper. While it is still under investigation where the gap comes from, we should implement the 2WG (and, in the future, 4WG) pipeline in FlashInfer to make sure our implementation does not fall behind FlashMLA in performance.
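For intuition, here is a minimal, hypothetical CUDA sketch of the 2WG idea: both warpgroups do math on the same staged KV tile (instead of dedicating one warpgroup purely to data movement, as in a 3WG producer/consumer design). This is not the actual kernel: the real implementation moves tiles with TMA and multiplies with WGMMA, and its work split may differ; all names here (`mla_2wg_sketch`, `TILE_KV`, etc.) are illustrative.

```cuda
#include <cuda_runtime.h>

constexpr int WG_SIZE = 128;   // one Hopper warpgroup = 4 warps
constexpr int NUM_WG = 2;      // 2WG design: 256 threads per CTA
constexpr int TILE_KV = 64;    // KV rows staged in shared memory per step
constexpr int HEAD_DIM = 64;   // toy head dim (MLA's real KV dim is larger)

// Toy QK^T pass: both warpgroups compute on the same staged KV tile,
// splitting the query heads between them. Launch with 256 threads per CTA.
__global__ void mla_2wg_sketch(const float* __restrict__ q,   // [num_heads, HEAD_DIM]
                               const float* __restrict__ kv,  // [kv_len, HEAD_DIM]
                               float* __restrict__ scores,    // [num_heads, kv_len]
                               int num_heads, int kv_len) {
  __shared__ float kv_tile[TILE_KV][HEAD_DIM];
  const int wg = threadIdx.x / WG_SIZE;  // warpgroup index: 0 or 1
  const int lane = threadIdx.x % WG_SIZE;
  // Each warpgroup owns half of the heads (num_heads assumed even),
  // rather than one warpgroup being a dedicated producer.
  const int h_begin = wg * (num_heads / NUM_WG);
  const int h_end = h_begin + num_heads / NUM_WG;

  for (int tile = 0; tile < kv_len; tile += TILE_KV) {
    // All 256 threads cooperatively stage the KV tile.
    for (int i = threadIdx.x; i < TILE_KV * HEAD_DIM; i += WG_SIZE * NUM_WG) {
      int r = i / HEAD_DIM, c = i % HEAD_DIM;
      kv_tile[r][c] = (tile + r < kv_len) ? kv[(size_t)(tile + r) * HEAD_DIM + c] : 0.f;
    }
    __syncthreads();  // tile is ready for both warpgroups

    // Both warpgroups do math, each on its own half of the heads.
    for (int h = h_begin; h < h_end; ++h) {
      for (int r = lane; r < TILE_KV && tile + r < kv_len; r += WG_SIZE) {
        float acc = 0.f;
        for (int c = 0; c < HEAD_DIM; ++c)
          acc += q[h * HEAD_DIM + c] * kv_tile[r][c];
        scores[(size_t)h * kv_len + tile + r] = acc;
      }
    }
    __syncthreads();  // per-iteration sync before the tile is overwritten (see note 3)
  }
}
```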
## Performance
Before this PR:
```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1547.23 GB/s
FLOPs: 167.29 TFLOPs
Config: batch_size=64, seq_len=1024, num_heads=128
Memory bandwidth: 1483.82 GB/s
FLOPs: 290.23 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 2238.72 GB/s
FLOPs: 242.06 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=128
Memory bandwidth: 1612.66 GB/s
FLOPs: 315.43 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2821.32 GB/s
FLOPs: 305.05 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=128
Memory bandwidth: 1767.63 GB/s
FLOPs: 345.74 TFLOPs
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 1960.50 GB/s
FLOPs: 223.79 TFLOPs
Config: batch_size=64, seq_len=2048, num_heads=128
Memory bandwidth: 1533.88 GB/s
FLOPs: 331.70 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 2546.83 GB/s
FLOPs: 290.72 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=128
Memory bandwidth: 1629.73 GB/s
FLOPs: 352.43 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2820.22 GB/s
FLOPs: 321.93 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=128
Memory bandwidth: 1657.89 GB/s
FLOPs: 358.52 TFLOPs
Config: batch_size=64, seq_len=8192, num_heads=64
Memory bandwidth: 2682.98 GB/s
FLOPs: 319.63 TFLOPs
Config: batch_size=64, seq_len=8192, num_heads=128
Memory bandwidth: 1600.79 GB/s
FLOPs: 375.94 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=64
Memory bandwidth: 2803.48 GB/s
FLOPs: 333.98 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=128
Memory bandwidth: 1584.79 GB/s
FLOPs: 372.18 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=64
Memory bandwidth: 2768.36 GB/s
FLOPs: 329.80 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=128
Memory bandwidth: 1565.82 GB/s
FLOPs: 367.73 TFLOPs
```
After this PR (gains are largest for the num_heads=128 configurations; some num_heads=64 configurations are roughly unchanged):
```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1509.87 GB/s
FLOPs: 163.25 TFLOPs
Config: batch_size=64, seq_len=1024, num_heads=128
Memory bandwidth: 1766.19 GB/s
FLOPs: 345.46 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 2307.97 GB/s
FLOPs: 249.55 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=128
Memory bandwidth: 1975.24 GB/s
FLOPs: 386.35 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2871.63 GB/s
FLOPs: 310.49 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=128
Memory bandwidth: 2225.07 GB/s
FLOPs: 435.21 TFLOPs
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 1948.15 GB/s
FLOPs: 222.38 TFLOPs
Config: batch_size=64, seq_len=2048, num_heads=128
Memory bandwidth: 1973.36 GB/s
FLOPs: 426.74 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 2625.63 GB/s
FLOPs: 299.72 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=128
Memory bandwidth: 2121.92 GB/s
FLOPs: 458.86 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2996.11 GB/s
FLOPs: 342.01 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=128
Memory bandwidth: 2146.40 GB/s
FLOPs: 464.16 TFLOPs
Config: batch_size=64, seq_len=8192, num_heads=64
Memory bandwidth: 2717.28 GB/s
FLOPs: 323.71 TFLOPs
Config: batch_size=64, seq_len=8192, num_heads=128
Memory bandwidth: 2129.24 GB/s
FLOPs: 500.04 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=64
Memory bandwidth: 3002.75 GB/s
FLOPs: 357.72 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=128
Memory bandwidth: 2101.93 GB/s
FLOPs: 493.63 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=64
Memory bandwidth: 3083.42 GB/s
FLOPs: 367.33 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=128
Memory bandwidth: 2064.96 GB/s
FLOPs: 484.95 TFLOPs
```
## Notes
1. The profiler is broken (we changed the pipeline structure and will add profiling back in later PRs).
2. There is still room for improvement in the pipeline design, e.g. we could prefetch the first KV-cache tile of the next work unit, which could further improve performance; we leave this for future work.
<img width="1230" alt="image"
src="https://github.com/user-attachments/assets/e84b1d55-3361-48a1-b339-97837cb97bfb"
/>
3. Synchronization is still sub-optimal: we insert a `__syncthreads()` in each iteration to guarantee correctness. Can we further improve this? One possible direction is sketched below.
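For question 3, one possible direction (a hypothetical sketch assuming a 256-thread, 2WG CTA; not code from this PR) is to replace the CTA-wide `__syncthreads()` with PTX named barriers (`bar.arrive`/`bar.sync`, the primitive behind `cutlass::arch::NamedBarrier`): the producing warpgroup signals that data is ready without blocking, and only waits when it is about to overwrite the shared buffer.

```cuda
#include <cuda_runtime.h>

constexpr int WG_SIZE = 128;
constexpr int CTA_SIZE = 256;  // 2 warpgroups

// All `count` threads that hit barrier `id` rendezvous (id in [0, 16);
// id 0 is what __syncthreads() uses, so we pick 1 and 2 here).
__device__ __forceinline__ void named_barrier_sync(int id, int count) {
  asm volatile("bar.sync %0, %1;" ::"r"(id), "r"(count) : "memory");
}

// Arrive without waiting: bumps the barrier's count and returns at once.
__device__ __forceinline__ void named_barrier_arrive(int id, int count) {
  asm volatile("bar.arrive %0, %1;" ::"r"(id), "r"(count) : "memory");
}

// Warpgroup 0 stages data, warpgroup 1 consumes it. The producer never
// blocks on the "data ready" barrier, so it could already start work
// (e.g. prefetching the next KV tile, note 2) for iteration it + 1.
__global__ void relaxed_sync_sketch(const float* __restrict__ in,
                                    float* __restrict__ out, int iters) {
  __shared__ float stage[WG_SIZE];
  const int wg = threadIdx.x / WG_SIZE;
  const int lane = threadIdx.x % WG_SIZE;

  for (int it = 0; it < iters; ++it) {
    if (wg == 0) {
      stage[lane] = in[(size_t)it * WG_SIZE + lane];   // produce the tile
      named_barrier_arrive(/*id=*/1, CTA_SIZE);        // signal, don't wait
    } else {
      named_barrier_sync(/*id=*/1, CTA_SIZE);          // wait for the tile
      out[(size_t)it * WG_SIZE + lane] = stage[lane] * 2.0f;  // consume it
      named_barrier_arrive(/*id=*/2, CTA_SIZE);        // slot is free again
    }
    if (wg == 0 && it + 1 < iters) {
      named_barrier_sync(/*id=*/2, CTA_SIZE);  // only block right before reuse
    }
  }
}
```

With multiple shared-memory stages, a distinct barrier id per stage would let the two warpgroups drift further apart instead of meeting once per iteration.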
## Files changed
2 files changed: +195, -249 lines
- benchmarks/bench_deepseek_mla.py (+3, -2)
- include/flashinfer/attention