Commit 3de690a
feat: unlock MLA attention for sm89 (L40/L40s/4090) (#814)
This PR changes the MLA attention template to support sm89 GPUs, which
have a small shared memory size (99 KB per SM). To fit within that limit
we have to reduce shared memory usage further: `NUM_STAGES` can only be
set to 1, and `CTA_TILE_KV` can be set to at most 16.
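To make the constraint concrete, here is a back-of-the-envelope budget, not the kernel's actual layout. It assumes the Q tile and `NUM_STAGES` KV tiles dominate shared memory and ignores every other buffer. The head dimensions (512 + 64), the 2-byte element size, and `CTA_TILE_Q = 64` are illustrative assumptions that do not come from this PR description; only the 99 KB limit is taken from the text above.

```python
# Simplified shared-memory budget for one CTA.
# All sizes here are illustrative assumptions, not values read from the kernel.
HEAD_DIM_CKV = 512   # assumed compressed-KV head dimension
HEAD_DIM_KPE = 64    # assumed positional-embedding head dimension
BYTES_PER_ELEM = 2   # assumed fp16/bf16 storage
CTA_TILE_Q = 64      # assumed query tile size

SM89_LIMIT_KB = 99   # shared memory per SM on sm89, from the description


def smem_bytes(num_stages: int, cta_tile_kv: int) -> int:
    """Bytes for the Q tile plus num_stages KV tiles (other buffers ignored)."""
    row_bytes = (HEAD_DIM_CKV + HEAD_DIM_KPE) * BYTES_PER_ELEM
    q_bytes = CTA_TILE_Q * row_bytes
    kv_bytes = num_stages * cta_tile_kv * row_bytes
    return q_bytes + kv_bytes


def fits(num_stages: int, cta_tile_kv: int, limit_kb: int) -> bool:
    """True if the simplified budget stays within limit_kb kilobytes."""
    return smem_bytes(num_stages, cta_tile_kv) <= limit_kb * 1024


if __name__ == "__main__":
    for num_stages, cta_tile_kv in [(1, 16), (1, 32), (2, 16)]:
        kb = smem_bytes(num_stages, cta_tile_kv) / 1024
        ok = fits(num_stages, cta_tile_kv, SM89_LIMIT_KB)
        print(f"NUM_STAGES={num_stages} CTA_TILE_KV={cta_tile_kv}: {kb:.0f} KB fits={ok}")
```

Under these assumed sizes, only the `NUM_STAGES=1`, `CTA_TILE_KV=16` combination stays below 99 KB. The real kernel may allocate shared memory differently, so treat this as intuition rather than a derivation.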
We add an option `QK_SHARD` to the KernelTraits (our previous template
only supported the `QK_SHARD=true` behavior):
1. If true, we use the schedule described in #804 and shard the QK
computation along the KV dimension: each warpgroup computes half of it,
and we need one round of allgather through shared memory to obtain the
full P for the PV computation.
2. If false, we duplicate the QK computation on both warpgroups (which is
redundant work), but we skip the allgather step for P.
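The two options can be illustrated with a toy model in plain Python. This is a sketch of the data movement only: the two "warpgroups" are ordinary function calls, softmax is omitted so P is just the raw scores, and the helper names (`qk_scores`, `p_with_qk_shard`, `p_without_qk_shard`) are hypothetical and do not appear in the kernel.

```python
def qk_scores(q, k):
    """Return q @ k^T for row-major lists of lists."""
    return [[sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
            for q_row in q]


def p_with_qk_shard(q, k):
    """QK_SHARD=true: each warpgroup scores half of the KV tile."""
    half = len(k) // 2
    p_wg0 = qk_scores(q, k[:half])   # warpgroup 0: first half of KV
    p_wg1 = qk_scores(q, k[half:])   # warpgroup 1: second half of KV
    # Allgather: concatenate the halves along the KV dimension so that
    # both warpgroups see the full P before the PV computation.
    return [row0 + row1 for row0, row1 in zip(p_wg0, p_wg1)]


def p_without_qk_shard(q, k):
    """QK_SHARD=false: both warpgroups redundantly score the whole KV tile."""
    p_wg0 = qk_scores(q, k)
    p_wg1 = qk_scores(q, k)          # duplicated work, no exchange needed
    assert p_wg0 == p_wg1
    return p_wg0


q = [[1, 2], [3, 4]]                       # 2 queries, head_dim 2
k = [[1, 0], [0, 1], [1, 1], [2, -1]]      # KV tile of 4 rows

if __name__ == "__main__":
    print(p_with_qk_shard(q, k))
    print(p_without_qk_shard(q, k))
```

Both paths produce the same P. The trade-off is that sharding halves the QK work per warpgroup but needs a shared-memory buffer for the exchange, while duplication spends extra compute to avoid that buffer.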
We set `QK_SHARD=true` for A100/H100 (shared memory limits of 164 KB and
228 KB, respectively) and `QK_SHARD=false` for sm89.
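The resulting choice can be summarized as a small lookup table. The sketch below only restates the values given above; the table name, the helper `qk_shard_for`, and the use of `sm80`/`sm90` as keys for A100/H100 are illustrative and are not the dispatch logic used in the code.

```python
# (shared memory per SM in KB, QK_SHARD), as stated in the description.
ARCH_CONFIG = {
    "sm80": (164, True),    # A100
    "sm89": (99, False),    # L40 / L40S / 4090
    "sm90": (228, True),    # H100
}


def qk_shard_for(arch: str) -> bool:
    """Return the QK_SHARD setting described for the given architecture."""
    return ARCH_CONFIG[arch][1]


if __name__ == "__main__":
    for arch, (smem_kb, qk_shard) in sorted(ARCH_CONFIG.items()):
        print(f"{arch}: {smem_kb} KB shared memory, QK_SHARD={qk_shard}")
```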
## Reference
The effect of `QK_SHARD` on H100 SXM5 (3352 GB/s):
```
QK_SHARD=true (Allgather with shared memory)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 2010.78 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 2036.13 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2085.52 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 2068.62 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 2085.84 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2080.85 GB/s
QK_SHARD=false (Duplicate P)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 1610.81 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 1638.73 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 1690.86 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 1636.08 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 1651.57 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 1653.31 GB/s
```
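To read the H100 numbers above more easily, the short script below computes the fraction of the 3352 GB/s figure that each configuration reaches and the ratio between the two settings. The measurements are copied from the listing above; the variable and function names are illustrative.

```python
H100_PEAK_GBPS = 3352.0

# Measured bandwidth in GB/s, in the order listed above:
# (seq_len, num_heads) = (1024, 16), (1024, 32), (1024, 64),
#                        (2048, 16), (2048, 32), (2048, 64)
QK_SHARD_TRUE = [2010.78, 2036.13, 2085.52, 2068.62, 2085.84, 2080.85]
QK_SHARD_FALSE = [1610.81, 1638.73, 1690.86, 1636.08, 1651.57, 1653.31]


def utilization(measured_gbps: float, peak_gbps: float) -> float:
    """Fraction of the peak bandwidth that was achieved."""
    return measured_gbps / peak_gbps


def speedups(sharded, duplicated):
    """Per-config ratio of QK_SHARD=true over QK_SHARD=false bandwidth."""
    return [s / d for s, d in zip(sharded, duplicated)]


if __name__ == "__main__":
    for s, d in zip(QK_SHARD_TRUE, QK_SHARD_FALSE):
        print(f"true: {utilization(s, H100_PEAK_GBPS):.1%}  "
              f"false: {utilization(d, H100_PEAK_GBPS):.1%}  "
              f"ratio: {s / d:.2f}x")
```

On these measurements, `QK_SHARD=true` delivers roughly 1.23x to 1.26x the bandwidth of `QK_SHARD=false` on H100.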
The effect of `QK_SHARD` on A100 SXM 40GB (1555 GB/s):
```
QK_SHARD=true (Allgather with shared memory)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 891.30 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 929.65 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 954.24 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 923.07 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 933.77 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 943.48 GB/s
QK_SHARD=false (Duplicate P)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 753.89 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 780.96 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 804.61 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 785.70 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 796.87 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 808.83 GB/s
```
Parent commit: e1880b1
1 file changed, +241 -133 lines changed