
Commit 3de690a

feat: unlock MLA attention for sm89 (L40/L40s/4090) (#814)
This PR changes the MLA attention template to support sm89 GPUs (L40/L40S/4090), which have a small shared memory size (99 KB per SM), so we have to further reduce shared memory usage: `NUM_STAGES` can only be set to 1, and `CTA_TILE_KV` can be set to at most 16. We add an option `QK_SHARD` to the KernelTraits (our previous template only supported `QK_SHARD=true`):
1. If true, we use the schedule mentioned in #804 and shard the QK computation along the KV dimension: each warpgroup computes half of it, and we then perform a round of allgather through shared memory to obtain the full P for the PV computation.
2. If false, we duplicate the QK computation across the two warpgroups (which is redundant work), but we save the allgather step for P.

We set `QK_SHARD=true` for A100/H100 (shared memory limits of 164 KB and 228 KB, respectively), and `QK_SHARD=false` for sm89.

## Reference

The effect of `QK_SHARD` on H100 SXM5 (3352 GB/s):
```
QK_SHARD=true (Allgather with shared memory)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 2010.78 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 2036.13 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2085.52 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 2068.62 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 2085.84 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2080.85 GB/s

QK_SHARD=false (Duplicate P)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 1610.81 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 1638.73 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 1690.86 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 1636.08 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 1651.57 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 1653.31 GB/s
```

The effect of `QK_SHARD` on A100 SXM 40GB (1555 GB/s):
```
QK_SHARD=true (Allgather with shared memory)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 891.30 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 929.65 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 954.24 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 923.07 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 933.77 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 943.48 GB/s

QK_SHARD=false (Duplicate P)
=======================
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 753.89 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 780.96 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 804.61 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 785.70 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 796.87 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 808.83 GB/s
```
1 parent e1880b1 commit 3de690a

File tree: 1 file changed, +241 −133 lines

