Commit 978b45f

[Kernel] Flash Attention 3 Support (#12093)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent c5b4b11 commit 978b45f

8 files changed, +150 -82 lines changed

CMakeLists.txt

Lines changed: 20 additions & 25 deletions
@@ -24,9 +24,6 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
 # Suppress potential warnings about unused manually-specified variables
 set(ignoreMe "${VLLM_PYTHON_PATH}")
 
-# Prevent installation of dependencies (cutlass) by default.
-install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
-
 #
 # Supported python versions. These versions will be searched in order, the
 # first match will be selected. These should be kept in sync with setup.py.

@@ -535,7 +532,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
 endif()
 
 # vllm-flash-attn currently only supported on CUDA
-if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
+if (NOT VLLM_GPU_LANG STREQUAL "CUDA")
   return()
 endif ()

@@ -558,7 +555,7 @@ endif()
 # They should be identical but if they aren't, this is a massive footgun.
 #
 # The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
-# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
+# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
 # If no component is specified, vllm-flash-attn is still installed.
 
 # If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.

@@ -570,43 +567,41 @@ if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
 endif()
 
 if(VLLM_FLASH_ATTN_SRC_DIR)
-  FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
+  FetchContent_Declare(
+        vllm-flash-attn SOURCE_DIR
+        ${VLLM_FLASH_ATTN_SRC_DIR}
+        BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
+  )
 else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c
+        GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
   )
 endif()
 
-# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
-set(VLLM_PARENT_BUILD ON)
-
-# Ensure the vllm/vllm_flash_attn directory exists before installation
-install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)
-
-# Make sure vllm-flash-attn install rules are nested under vllm/
-install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
-install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
-install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)
 
 # Fetch the vllm-flash-attn library
 FetchContent_MakeAvailable(vllm-flash-attn)
 message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
 
-# Restore the install prefix
-install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
-install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)
+# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
+# case only one is built, in the case both are built redundant work is done)
+install(
+  DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
+  DESTINATION vllm_flash_attn
+  COMPONENT _vllm_fa2_C
+  FILES_MATCHING PATTERN "*.py"
+)
 
-# Copy over the vllm-flash-attn python files
 install(
-  DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
-  DESTINATION vllm/vllm_flash_attn
-  COMPONENT vllm_flash_attn_c
-  FILES_MATCHING PATTERN "*.py"
+  DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
+  DESTINATION vllm_flash_attn
+  COMPONENT _vllm_fa3_C
+  FILES_MATCHING PATTERN "*.py"
 )
 
 # Nothing after vllm-flash-attn, see comment about macros above
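
With the new DESTINATION/COMPONENT names above, the Python sources and both compiled extensions are expected to land side by side in the package (a sketch of the layout; the compiled-extension file names are taken from the files_to_copy list in setup.py below, everything else is illustrative):

    # Expected on-disk layout after installation (illustrative):
    #   vllm/vllm_flash_attn/__init__.py
    #   vllm/vllm_flash_attn/flash_attn_interface.py
    #   vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so   # installed by component _vllm_fa2_C
    #   vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so   # installed by component _vllm_fa3_C
    from vllm.vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache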

setup.py

Lines changed: 8 additions & 4 deletions
@@ -228,8 +228,11 @@ def target_name(s: str) -> str:
 
         # CMake appends the extension prefix to the install path,
         # and outdir already contains that prefix, so we need to remove it.
+        # We assume only the final component of extension prefix is added by
+        # CMake, this is currently true for current extensions but may not
+        # always be the case.
         prefix = outdir
-        for i in range(ext.name.count('.')):
+        if '.' in ext.name:
             prefix = prefix.parent
 
         # prefix here should actually be the same for all components

@@ -298,7 +301,8 @@ def run(self) -> None:
         files_to_copy = [
             "vllm/_C.abi3.so",
             "vllm/_moe_C.abi3.so",
-            "vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so",
+            "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
+            "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
             "vllm/vllm_flash_attn/flash_attn_interface.py",
             "vllm/vllm_flash_attn/__init__.py",
             "vllm/cumem_allocator.abi3.so",

@@ -593,8 +597,8 @@ def _read_requirements(filename: str) -> List[str]:
     ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
 
 if _is_cuda():
-    ext_modules.append(
-        CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
+    ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
+    ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
     ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
 
 if _build_custom_ops():
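
To illustrate the prefix handling above in isolation (a sketch with hypothetical build paths; only the extension name is taken from the diff): with a relative DESTINATION in CMakeLists.txt, CMake is assumed to append only the final path component under the install prefix, so a single .parent hop recovers the prefix to pass to the install step.

    from pathlib import Path

    # e.g. for the FA2 extension, the build backend places the .so under outdir:
    ext_name = "vllm.vllm_flash_attn._vllm_fa2_C"
    outdir = Path("build/lib/vllm/vllm_flash_attn")  # hypothetical path

    # One hop up, because only "vllm_flash_attn" is re-appended by CMake's
    # install rules (DESTINATION vllm_flash_attn above).
    prefix = outdir
    if '.' in ext_name:
        prefix = prefix.parent
    print(prefix)  # build/lib/vllm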

tests/kernels/test_cascade_flash_attn.py

Lines changed: 14 additions & 10 deletions
@@ -78,6 +78,7 @@ def test_merge_kernel(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("soft_cap", [None, 50])
 @pytest.mark.parametrize("num_blocks", [2048])
+@pytest.mark.parametrize("fa_version", [2, 3])
 @torch.inference_mode()
 def test_cascade(
     seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],

@@ -87,8 +88,14 @@ def test_cascade(
     block_size: int,
     soft_cap: Optional[float],
     num_blocks: int,
+    fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
+    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
+                            or torch.cuda.get_device_capability() == (8, 9)):
+        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
+                    "insufficient shared memory for some shapes")
+
     current_platform.seed_everything(0)
 
     window_size = (-1, -1)

@@ -118,9 +125,7 @@ def test_cascade(
     cu_query_lens = torch.tensor([0] + query_lens,
                                  dtype=torch.int32).cumsum(dim=0,
                                                            dtype=torch.int32)
-    cu_kv_lens = torch.tensor([0] + kv_lens,
-                              dtype=torch.int32).cumsum(dim=0,
-                                                        dtype=torch.int32)
+    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
     max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
     block_tables = torch.randint(0,
                                  num_blocks,

@@ -140,7 +145,7 @@ def test_cascade(
         k=key_cache,
         v=value_cache,
         cu_seqlens_q=cu_query_lens,
-        cu_seqlens_k=cu_kv_lens,
+        seqused_k=kv_lens_tensor,
         max_seqlen_q=max_query_len,
         max_seqlen_k=max_kv_len,
         softmax_scale=scale,

@@ -154,10 +159,8 @@ def test_cascade(
     assert all(common_prefix_len < kv_len for kv_len in kv_lens)
     cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens],
                                         dtype=torch.int32)
-    cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], dtype=torch.int32)
-    cu_suffix_kv_lens = (
-        cu_kv_lens -
-        torch.arange(num_seqs + 1, dtype=torch.int32) * common_prefix_len)
+    prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
+    suffix_kv_lens = kv_lens_tensor - common_prefix_len
     output = torch.empty_like(query)
     cascade_attention(
         output=output,

@@ -167,15 +170,16 @@ def test_cascade(
         cu_query_lens=cu_query_lens,
         max_query_len=max_query_len,
         cu_prefix_query_lens=cu_prefix_query_lens,
-        cu_prefix_kv_lens=cu_prefix_kv_lens,
-        cu_suffix_kv_lens=cu_suffix_kv_lens,
+        prefix_kv_lens=prefix_kv_lens,
+        suffix_kv_lens=suffix_kv_lens,
         max_kv_len=max_kv_len,
         softmax_scale=scale,
         alibi_slopes=None,
         sliding_window=window_size,
         logits_soft_cap=soft_cap if soft_cap is not None else 0,
         block_table=block_tables,
         common_prefix_len=common_prefix_len,
+        fa_version=fa_version,
     )
 
     # Compare the results.
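
The argument change being exercised here, shown in isolation (a small sketch, not part of the commit): the cumulative key offsets previously passed as cu_seqlens_k are replaced by raw per-sequence key lengths passed as seqused_k.

    import torch

    kv_lens = [10, 7, 12]
    # Old layout: cumulative key offsets, length num_seqs + 1.
    cu_kv_lens = torch.tensor([0] + kv_lens,
                              dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)
    # -> tensor([ 0, 10, 17, 29], dtype=torch.int32)
    # New layout: per-sequence used key lengths, length num_seqs, via seqused_k.
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
    # -> tensor([10,  7, 12], dtype=torch.int32)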

tests/kernels/test_flash_attn.py

Lines changed: 18 additions & 4 deletions
@@ -80,6 +80,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("sliding_window", [None, 256])
+@pytest.mark.parametrize("fa_version", [2, 3])
 @torch.inference_mode()
 def test_flash_attn_with_paged_kv(
     use_out: bool,

@@ -91,8 +92,14 @@ def test_flash_attn_with_paged_kv(
     soft_cap: Optional[float],
     num_blocks: int,
     sliding_window: Optional[int],
+    fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
+    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
+                            or torch.cuda.get_device_capability() == (8, 9)):
+        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
+                    "insufficient shared memory for some shapes")
+
     current_platform.seed_everything(0)
     num_seqs = len(kv_lens)
     num_query_heads = num_heads[0]

@@ -131,6 +138,7 @@ def test_flash_attn_with_paged_kv(
         cache_seqlens=kv_lens_tensor,
         softcap=soft_cap if soft_cap is not None else 0,
         window_size=window_size,
+        fa_version=fa_version,
     )
     output = output if not use_out else out
     output = output.squeeze(1)

@@ -159,6 +167,7 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("fa_version", [2, 3])
 @torch.inference_mode()
 def test_varlen_with_paged_kv(
     use_out: bool,

@@ -170,8 +179,14 @@ def test_varlen_with_paged_kv(
     block_size: int,
     soft_cap: Optional[float],
     num_blocks: int,
+    fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
+    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
+                            or torch.cuda.get_device_capability() == (8, 9)):
+        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
+                    "insufficient shared memory for some shapes")
+
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]

@@ -198,9 +213,7 @@ def test_varlen_with_paged_kv(
     cu_query_lens = torch.tensor([0] + query_lens,
                                  dtype=torch.int32).cumsum(dim=0,
                                                            dtype=torch.int32)
-    cu_kv_lens = torch.tensor([0] + kv_lens,
-                              dtype=torch.int32).cumsum(dim=0,
-                                                        dtype=torch.int32)
+    kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
 
     max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
     block_tables = torch.randint(0,

@@ -215,14 +228,15 @@ def test_varlen_with_paged_kv(
         v=value_cache,
         out=out,
         cu_seqlens_q=cu_query_lens,
-        cu_seqlens_k=cu_kv_lens,
+        seqused_k=kv_lens,
         max_seqlen_q=max_query_len,
         max_seqlen_k=max_kv_len,
         softmax_scale=scale,
         causal=True,
         window_size=window_size,
         block_table=block_tables,
         softcap=soft_cap if soft_cap is not None else 0,
+        fa_version=fa_version,
     )
     output = output if not use_out else out
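
The capability gate above is repeated verbatim in each parametrized test; a hypothetical helper (not part of this commit) capturing the same skip condition could look like this:

    import pytest
    import torch

    def maybe_skip_fa3(fa_version: int) -> None:
        # FA3 currently fails on SM 8.6 / 8.9 (insufficient shared memory for
        # some shapes), so tests parametrized with fa_version=3 skip there.
        if fa_version == 3 and torch.cuda.get_device_capability() in [(8, 6), (8, 9)]:
            pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
                        "insufficient shared memory for some shapes")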

vllm/attention/backends/flash_attn.py

Lines changed: 24 additions & 3 deletions
@@ -17,15 +17,18 @@
     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
+from vllm.envs import VLLM_FLASH_ATTN_VERSION
 from vllm.multimodal import MultiModalPlaceholderMap
+from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)
 
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache)
+                                  flash_attn_with_kvcache,
+                                  is_fa_version_supported)
 
 
 class FlashAttentionBackend(AttentionBackend):

@@ -634,6 +637,20 @@ def __init__(
                 f"Supported head sizes are: {support_head_sizes}.")
         self.attn_type = attn_type
 
+        # if hopper default to FA3, otherwise stick to FA2 for now
+        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
+        # use FA3 as default for both
+        if current_platform.get_device_capability()[0] >= 9:
+            self.fa_version = 3 if is_fa_version_supported(3) else 2
+        else:
+            self.fa_version = 2
+
+        if VLLM_FLASH_ATTN_VERSION is not None:
+            assert VLLM_FLASH_ATTN_VERSION in [2, 3]
+            self.fa_version = VLLM_FLASH_ATTN_VERSION
+
+        assert is_fa_version_supported(self.fa_version)
+
     def forward(
         self,
         layer: AttentionLayer,

@@ -752,6 +769,7 @@ def forward(
                 alibi_slopes=alibi_slopes,
                 softcap=logits_soft_cap,
                 out=prefill_output,
+                fa_version=self.fa_version,
             )
         else:
             # prefix-enabled attention

@@ -765,7 +783,7 @@ def forward(
                     v=value_cache,
                     cu_seqlens_q=prefill_meta.query_start_loc,
                     max_seqlen_q=prefill_meta.max_query_len,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    seqused_k=prefill_meta.seq_lens_tensor,
                     max_seqlen_k=max_seq_len,
                     softmax_scale=softmax_scale,
                     causal=True,

@@ -774,6 +792,7 @@ def forward(
                     block_table=prefill_meta.block_tables,
                     softcap=logits_soft_cap,
                     out=prefill_output,
+                    fa_version=self.fa_version,
                 )
 
         if decode_meta := attn_metadata.decode_metadata:

@@ -793,7 +812,7 @@ def forward(
                     v=value_cache,
                     cu_seqlens_q=decode_meta.query_start_loc,
                     max_seqlen_q=decode_meta.max_decode_query_len,
-                    cu_seqlens_k=decode_meta.seq_start_loc,
+                    seqused_k=decode_meta.seq_lens_tensor,
                     max_seqlen_k=decode_meta.max_decode_seq_len,
                     softmax_scale=softmax_scale,
                     causal=True,

@@ -802,6 +821,7 @@ def forward(
                     softcap=logits_soft_cap,
                     block_table=decode_meta.block_tables,
                     out=decode_output,
+                    fa_version=self.fa_version,
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.

@@ -822,6 +842,7 @@ def forward(
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=decode_output.unsqueeze(1),
+                    fa_version=self.fa_version,
                 )
         return output
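
Pulling the new version-selection logic out of __init__ for readability (a sketch that mirrors the diff above; the standalone function name is mine, and VLLM_FLASH_ATTN_VERSION is assumed to be the integer override imported from vllm.envs):

    from typing import Optional

    from vllm.platforms import current_platform
    from vllm.vllm_flash_attn import is_fa_version_supported

    def select_fa_version(override: Optional[int] = None) -> int:
        # Default to FA3 on Hopper (compute capability 9.x) when the FA3
        # extension was built; otherwise stay on FA2 for now.
        if current_platform.get_device_capability()[0] >= 9:
            fa_version = 3 if is_fa_version_supported(3) else 2
        else:
            fa_version = 2
        # Explicit override, e.g. via VLLM_FLASH_ATTN_VERSION.
        if override is not None:
            assert override in [2, 3]
            fa_version = override
        assert is_fa_version_supported(fa_version)
        return fa_version

Assuming vllm.envs reads this value from an environment variable of the same name, setting VLLM_FLASH_ATTN_VERSION=2 would force the FA2 kernels even on Hopper, which is handy when comparing the two code paths.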
