perf: Use 2WG pipeline design for MLA implementation on Hopper #952


Merged: 20 commits merged into flashinfer-ai:main on Mar 29, 2025

Conversation

@yzh119 (Collaborator) commented Mar 17, 2025

This PR implements #892.

Per our benchmarks, the 2WG pipeline (FlashMLA's design) is faster than our current 3WG pipeline design on Hopper. While it is still under investigation where the gap comes from, we should implement the 2WG (and, in the future, 4WG) pipeline in FlashInfer to make sure our implementation does not fall behind FlashMLA.

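To make the terminology concrete, below is a minimal, hypothetical sketch of a warpgroup-specialized mainloop on Hopper: the block runs two warpgroups (2WG) that share the work of one KV tile per iteration, with block-wide barriers separating the load and compute stages. The helper names and the exact work split are assumptions for illustration; the real FlashInfer/FlashMLA kernels use TMA copies, WGMMA, and online-softmax state that are omitted here.

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// One warpgroup is 4 warps (128 threads) on Hopper; a 2WG block has 256 threads.
constexpr int kWarpgroupSize = 128;
constexpr int kNumWarpgroups = 2;

__global__ void mla_2wg_skeleton(const __half* __restrict__ kv_cache,
                                 __half* __restrict__ out, int num_kv_tiles) {
  const int wg_idx = threadIdx.x / kWarpgroupSize;  // 0 or 1

  extern __shared__ __half smem_kv[];  // staging buffer for the current KV tile

  for (int tile = 0; tile < num_kv_tiles; ++tile) {
    // Stage 1: both warpgroups cooperatively stage this KV tile into shared
    // memory (the real kernel issues asynchronous TMA copies instead).
    // load_kv_tile(smem_kv, kv_cache, tile, wg_idx);   // hypothetical helper

    __syncthreads();  // the tile must be fully resident before compute starts

    // Stage 2: each warpgroup computes its slice of q*k^T, the online-softmax
    // update, and p*v for this tile (WGMMA in the real kernel).
    // compute_attention_slice(out, smem_kv, wg_idx);   // hypothetical helper

    __syncthreads();  // the per-iteration barrier mentioned in the notes below
  }
}
```
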
Performance

Before this PR:

```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1547.23 GB/s
FLOPs: 167.29 TFLOPs
Config: batch_size=64, seq_len=1024, num_heads=128
Memory bandwidth: 1483.82 GB/s
FLOPs: 290.23 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 2238.72 GB/s
FLOPs: 242.06 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=128
Memory bandwidth: 1612.66 GB/s
FLOPs: 315.43 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2821.32 GB/s
FLOPs: 305.05 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=128
Memory bandwidth: 1767.63 GB/s
FLOPs: 345.74 TFLOPs
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 1960.50 GB/s
FLOPs: 223.79 TFLOPs
Config: batch_size=64, seq_len=2048, num_heads=128
Memory bandwidth: 1533.88 GB/s
FLOPs: 331.70 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 2546.83 GB/s
FLOPs: 290.72 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=128
Memory bandwidth: 1629.73 GB/s
FLOPs: 352.43 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2820.22 GB/s
FLOPs: 321.93 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=128
Memory bandwidth: 1657.89 GB/s
FLOPs: 358.52 TFLOPs
Config: batch_size=64, seq_len=8192, num_heads=64
Memory bandwidth: 2682.98 GB/s
FLOPs: 319.63 TFLOPs
Config: batch_size=64, seq_len=8192, num_heads=128
Memory bandwidth: 1600.79 GB/s
FLOPs: 375.94 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=64
Memory bandwidth: 2803.48 GB/s
FLOPs: 333.98 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=128
Memory bandwidth: 1584.79 GB/s
FLOPs: 372.18 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=64
Memory bandwidth: 2768.36 GB/s
FLOPs: 329.80 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=128
Memory bandwidth: 1565.82 GB/s
FLOPs: 367.73 TFLOPs
```

After this PR:

```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1509.87 GB/s
FLOPs: 163.25 TFLOPs
Config: batch_size=64, seq_len=1024, num_heads=128
Memory bandwidth: 1766.19 GB/s
FLOPs: 345.46 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 2307.97 GB/s
FLOPs: 249.55 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=128
Memory bandwidth: 1975.24 GB/s
FLOPs: 386.35 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2871.63 GB/s
FLOPs: 310.49 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=128
Memory bandwidth: 2225.07 GB/s
FLOPs: 435.21 TFLOPs
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 1948.15 GB/s
FLOPs: 222.38 TFLOPs
Config: batch_size=64, seq_len=2048, num_heads=128
Memory bandwidth: 1973.36 GB/s
FLOPs: 426.74 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 2625.63 GB/s
FLOPs: 299.72 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=128
Memory bandwidth: 2121.92 GB/s
FLOPs: 458.86 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2996.11 GB/s
FLOPs: 342.01 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=128
Memory bandwidth: 2146.40 GB/s
FLOPs: 464.16 TFLOPs
Config: batch_size=64, seq_len=8192, num_heads=64
Memory bandwidth: 2717.28 GB/s
FLOPs: 323.71 TFLOPs
Config: batch_size=64, seq_len=8192, num_heads=128
Memory bandwidth: 2129.24 GB/s
FLOPs: 500.04 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=64
Memory bandwidth: 3002.75 GB/s
FLOPs: 357.72 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=128
Memory bandwidth: 2101.93 GB/s
FLOPs: 493.63 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=64
Memory bandwidth: 3083.42 GB/s
FLOPs: 367.33 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=128
Memory bandwidth: 2064.96 GB/s
FLOPs: 484.95 TFLOPs
```

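For reference, the bandwidth and FLOPs figures above can be roughly reproduced from each config alone, assuming fp16 storage and DeepSeek-style MLA head dimensions (head_dim_ckv = 512, head_dim_kpe = 64); the exact accounting in the benchmark script may differ. A back-of-the-envelope model:

```cuda
#include <cstdio>

int main() {
  // Example config from the tables above; the elapsed time is hypothetical.
  const double batch = 768, seq_len = 8192, num_heads = 64;
  const double d_ckv = 512, d_kpe = 64;     // assumed MLA head dimensions
  const double elapsed_s = 2.35e-3;         // hypothetical measured kernel time

  // Bytes moved (2 bytes/element): compressed KV cache, queries, outputs.
  const double bytes = 2.0 * (batch * seq_len * (d_ckv + d_kpe)      // KV cache
                              + batch * num_heads * (d_ckv + d_kpe)  // q
                              + batch * num_heads * d_ckv);          // o
  // FLOPs (2 per multiply-add): q*k^T over d_ckv + d_kpe, then p*v over d_ckv.
  const double flops = 2.0 * batch * num_heads * seq_len * ((d_ckv + d_kpe) + d_ckv);

  // With the 2.35 ms guess this prints ~3.1 TB/s and ~373 TFLOPs, in the same
  // ballpark as the batch_size=768, seq_len=8192, num_heads=64 row above.
  printf("Memory bandwidth: %.2f GB/s\n", bytes / elapsed_s / 1e9);
  printf("FLOPs: %.2f TFLOPs\n", flops / elapsed_s / 1e12);
  return 0;
}
```
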
Note

  1. The profiler is broken (we changed the pipeline structure and will add it back in later PRs).
  2. There is still room for improvement in the pipeline design: e.g., we could prefetch the next tile's first KV-cache load, which should further improve performance. We leave this for future work.
  3. Synchronization is still sub-optimal: we insert a `__syncthreads()` in each iteration to guarantee correctness. Can we further improve this? See the sketch after this list.

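As a rough illustration of notes 2 and 3, the sketch below software-pipelines the KV loads with `cuda::pipeline`: the copy of tile `i+1` is issued while tile `i` is being consumed, and the block-wide `__syncthreads()` is replaced by stage-scoped acquire/release waits. The buffer layout and names are assumptions for illustration, not the kernel's actual code (which would more likely use TMA and mbarriers).

```cuda
#include <cooperative_groups.h>
#include <cuda/pipeline>

__global__ void mla_prefetch_skeleton(const float4* __restrict__ kv_cache,
                                      int num_kv_tiles, int tile_elems) {
  // Two shared-memory stages: [0, tile_elems) and [tile_elems, 2 * tile_elems).
  extern __shared__ float4 smem[];
  auto block = cooperative_groups::this_thread_block();

  __shared__ cuda::pipeline_shared_state<cuda::thread_scope_block, 2> state;
  auto pipe = cuda::make_pipeline(block, &state);

  // Prefetch tile 0 before entering the mainloop.
  pipe.producer_acquire();
  cuda::memcpy_async(block, smem, kv_cache, sizeof(float4) * tile_elems, pipe);
  pipe.producer_commit();

  for (int tile = 0; tile < num_kv_tiles; ++tile) {
    // Issue the copy of tile + 1 into the other shared-memory stage so that it
    // overlaps with the compute on the current tile.
    if (tile + 1 < num_kv_tiles) {
      pipe.producer_acquire();
      cuda::memcpy_async(block, smem + ((tile + 1) % 2) * tile_elems,
                         kv_cache + (size_t)(tile + 1) * tile_elems,
                         sizeof(float4) * tile_elems, pipe);
      pipe.producer_commit();
    }

    // Wait only for the tile we are about to consume, compute on it, then
    // release its stage so the next copy can reuse the buffer.
    pipe.consumer_wait();
    // compute_attention_on(smem + (tile % 2) * tile_elems);  // hypothetical helper
    pipe.consumer_release();
  }
}
```

Whether this actually beats the current per-iteration `__syncthreads()` would need to be measured; the note above leaves it as future work.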
@yzh119 changed the title from "[WIP] Use 2WG pipeline design for MLA implementation on Hopper" to "perf: Use 2WG pipeline design for MLA implementation on Hopper" on Mar 26, 2025
@yzh119 marked this pull request as ready for review on March 26, 2025 08:59
@yzh119 merged commit 60d37b7 into flashinfer-ai:main on Mar 29, 2025
2 checks passed
yzh119 added a commit that referenced this pull request Mar 31, 2025
Follow-up of #952.

cc @abcdabcd987 

## Before this PR
```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1509.87 GB/s
FLOPs: 163.25 TFLOPs
Config: batch_size=64, seq_len=1024, num_heads=128
Memory bandwidth: 1766.19 GB/s
FLOPs: 345.46 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 2307.97 GB/s
FLOPs: 249.55 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=128
Memory bandwidth: 1975.24 GB/s
FLOPs: 386.35 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2871.63 GB/s
FLOPs: 310.49 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=128
Memory bandwidth: 2225.07 GB/s
FLOPs: 435.21 TFLOPs
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 1948.15 GB/s
FLOPs: 222.38 TFLOPs
Config: batch_size=64, seq_len=2048, num_heads=128
Memory bandwidth: 1973.36 GB/s
FLOPs: 426.74 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 2625.63 GB/s
FLOPs: 299.72 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=128
Memory bandwidth: 2121.92 GB/s
FLOPs: 458.86 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2996.11 GB/s
FLOPs: 342.01 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=128
Memory bandwidth: 2146.40 GB/s
FLOPs: 464.16 TFLOPs
Config: batch_size=64, seq_len=8192, num_heads=64
Memory bandwidth: 2717.28 GB/s
FLOPs: 323.71 TFLOPs
Config: batch_size=64, seq_len=8192, num_heads=128
Memory bandwidth: 2129.24 GB/s
FLOPs: 500.04 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=64
Memory bandwidth: 3002.75 GB/s
FLOPs: 357.72 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=128
Memory bandwidth: 2101.93 GB/s
FLOPs: 493.63 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=64
Memory bandwidth: 3083.42 GB/s
FLOPs: 367.33 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=128
Memory bandwidth: 2064.96 GB/s
FLOPs: 484.95 TFLOPs
```

## After this PR
```
Config: batch_size=64, seq_len=1024, num_heads=64
Memory bandwidth: 1596.98 GB/s
FLOPs: 172.67 TFLOPs
Config: batch_size=64, seq_len=1024, num_heads=128
Memory bandwidth: 1685.22 GB/s
FLOPs: 329.62 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=64
Memory bandwidth: 2280.49 GB/s
FLOPs: 246.58 TFLOPs
Config: batch_size=128, seq_len=1024, num_heads=128
Memory bandwidth: 1917.53 GB/s
FLOPs: 375.06 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2869.03 GB/s
FLOPs: 310.21 TFLOPs
Config: batch_size=768, seq_len=1024, num_heads=128
Memory bandwidth: 2208.35 GB/s
FLOPs: 431.94 TFLOPs
Config: batch_size=64, seq_len=2048, num_heads=64
Memory bandwidth: 2047.44 GB/s
FLOPs: 233.72 TFLOPs
Config: batch_size=64, seq_len=2048, num_heads=128
Memory bandwidth: 1936.08 GB/s
FLOPs: 418.67 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=64
Memory bandwidth: 2617.48 GB/s
FLOPs: 298.79 TFLOPs
Config: batch_size=128, seq_len=2048, num_heads=128
Memory bandwidth: 2105.97 GB/s
FLOPs: 455.41 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2999.55 GB/s
FLOPs: 342.40 TFLOPs
Config: batch_size=768, seq_len=2048, num_heads=128
Memory bandwidth: 2181.54 GB/s
FLOPs: 471.75 TFLOPs
Config: batch_size=64, seq_len=8192, num_heads=64
Memory bandwidth: 2780.86 GB/s
FLOPs: 331.29 TFLOPs
Config: batch_size=64, seq_len=8192, num_heads=128
Memory bandwidth: 2176.12 GB/s
FLOPs: 511.05 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=64
Memory bandwidth: 3031.58 GB/s
FLOPs: 361.15 TFLOPs
Config: batch_size=128, seq_len=8192, num_heads=128
Memory bandwidth: 2165.73 GB/s
FLOPs: 508.61 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=64
Memory bandwidth: 3126.37 GB/s
FLOPs: 372.45 TFLOPs
Config: batch_size=768, seq_len=8192, num_heads=128
Memory bandwidth: 2142.42 GB/s
FLOPs: 503.14 TFLOPs
```
MasterJH5574 pushed a commit that referenced this pull request Apr 10, 2025
Follow-up of #952: this PR adds the instrumentation code to profile the MLA Hopper implementation (fixes #995).