
Commit b168424

LucasWilkinson authored and tjtanaa committed
[Bugfix][Kernel] Fix CUDA 11.8 being broken by FA3 build (vllm-project#12375)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent 238f125 commit b168424

File tree

6 files changed (+42, −22 lines)


CMakeLists.txt

File mode changed: 100644 → 100755
Lines changed: 1 addition & 1 deletion
@@ -576,7 +576,7 @@ else()
         FetchContent_Declare(
                 vllm-flash-attn
                 GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-                GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954
+                GIT_TAG 0aff05f577e8a10086066a00618609199b25231d
                 GIT_PROGRESS TRUE
                 # Don't share the vllm-flash-attn build between build types
                 BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

setup.py

File mode changed: 100644 → 100755
Lines changed: 4 additions & 1 deletion
@@ -598,7 +598,10 @@ def _read_requirements(filename: str) -> List[str]:

 if _is_cuda():
     ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
-    ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
+    if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.0"):
+        # FA3 requires CUDA 12.0 or later
+        ext_modules.append(
+            CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
     ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))

 if _build_custom_ops():
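For context, here is a minimal, standalone sketch of the version gate added above. It assumes nvcc is on PATH and parses its release string; `nvcc_cuda_version` and `should_build_fa3` are illustrative names, not vLLM's actual helpers, and the real `get_nvcc_cuda_version` / `VLLM_USE_PRECOMPILED` handling may differ in detail.

```python
# Sketch only: approximates the gate
# `envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.0")`.
import re
import subprocess

from packaging.version import Version


def nvcc_cuda_version() -> Version:
    # Parse e.g. "Cuda compilation tools, release 11.8, V11.8.89".
    out = subprocess.check_output(["nvcc", "--version"], text=True)
    match = re.search(r"release (\d+\.\d+)", out)
    if match is None:
        raise RuntimeError("could not parse CUDA version from nvcc output")
    return Version(match.group(1))


def should_build_fa3(use_precompiled: bool) -> bool:
    # FA3 requires CUDA 12.0 or later; precompiled wheels skip compilation
    # entirely, so the local toolkit version does not matter there.
    return use_precompiled or nvcc_cuda_version() >= Version("12.0")
```

On a CUDA 11.8 toolchain this evaluates to False, so the `_vllm_fa3_C` extension is simply never added to `ext_modules`, which is what unbreaks the build.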

tests/kernels/test_cascade_flash_attn.py

File mode changed: 100644 → 100755
Lines changed: 6 additions & 5 deletions
@@ -6,7 +6,9 @@
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (cascade_attention,
                                                     merge_attn_states)
-from vllm.vllm_flash_attn import flash_attn_varlen_func
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
+                                  is_fa_version_supported)

 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 192, 256]
@@ -91,10 +93,9 @@ def test_cascade(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")

     current_platform.seed_everything(0)
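The same skip pattern recurs in the kernel tests below: instead of hard-coding compute capabilities, the test asks the library whether the requested FA version is usable and quotes its reason. A rough, self-contained illustration (not the vLLM test itself; `test_fa_version_gate` is a made-up name):

```python
# Illustrative only: uses the two helpers introduced by this commit.
import pytest

from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                  is_fa_version_supported)


@pytest.mark.parametrize("fa_version", [2, 3])
def test_fa_version_gate(fa_version: int) -> None:
    if not is_fa_version_supported(fa_version):
        pytest.skip(f"Flash attention version {fa_version} not supported due "
                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
    # ... the real tests exercise the attention kernels here ...
```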

tests/kernels/test_flash_attn.py

Lines changed: 10 additions & 11 deletions
@@ -4,8 +4,10 @@
 import torch

 from vllm.platforms import current_platform
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache)
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
+                                  flash_attn_with_kvcache,
+                                  is_fa_version_supported)

 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
@@ -95,10 +97,9 @@ def test_flash_attn_with_paged_kv(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")

     current_platform.seed_everything(0)
     num_seqs = len(kv_lens)
@@ -182,11 +183,9 @@ def test_varlen_with_paged_kv(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
-
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]

vllm/attention/backends/flash_attn.py

File mode changed: 100644 → 100755
Lines changed: 11 additions & 3 deletions
@@ -18,17 +18,20 @@
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
 from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
+                                  flash_attn_with_kvcache,
+                                  is_fa_version_supported)

 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)

-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache,
-                                  is_fa_version_supported)
+logger = init_logger(__name__)


 class FlashAttentionBackend(AttentionBackend):
@@ -652,6 +655,11 @@ def __init__(
         assert VLLM_FLASH_ATTN_VERSION in [2, 3]
         self.fa_version = VLLM_FLASH_ATTN_VERSION

+        if not is_fa_version_supported(self.fa_version):
+            logger.error("Cannot use FA version %d is not supported due to %s",
+                         self.fa_version,
+                         fa_version_unsupported_reason(self.fa_version))
+
         assert is_fa_version_supported(self.fa_version)

     def forward(
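The backend now logs the reason before the assertion fires, so a forced but unavailable FA version fails with an explanation. As a purely hypothetical sketch of what the support helpers might check (the real implementations ship with the vllm-flash-attn package and likely consider more conditions; names ending in `_sketch` are not vLLM API), one plausible check is whether the corresponding compiled extension exists at all, which is exactly what the setup.py gate affects on CUDA 11.8:

```python
# Hypothetical sketch; not the actual vllm-flash-attn implementation.
import importlib.util
from typing import Optional

_FA_MODULES = {
    2: "vllm.vllm_flash_attn._vllm_fa2_C",
    3: "vllm.vllm_flash_attn._vllm_fa3_C",
}


def fa_version_unsupported_reason_sketch(fa_version: int) -> Optional[str]:
    """Return None if the version looks usable, else a human-readable reason."""
    module = _FA_MODULES.get(fa_version)
    if module is None:
        return f"unknown flash-attention version {fa_version}"
    if importlib.util.find_spec(module) is None:
        # e.g. the FA3 extension is skipped at build time on CUDA < 12.0
        return f"{module} was not built for this installation"
    return None


def is_fa_version_supported_sketch(fa_version: int) -> bool:
    return fa_version_unsupported_reason_sketch(fa_version) is None
```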

vllm/v1/attention/backends/flash_attn.py

File mode changed: 100644 → 100755
Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@
1010
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
1111
AttentionMetadata, AttentionType)
1212
from vllm.envs import VLLM_FLASH_ATTN_VERSION
13+
from vllm.logger import init_logger
1314
from vllm.platforms import current_platform
1415
from vllm.utils import cdiv
15-
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
16+
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
17+
flash_attn_varlen_func,
1618
is_fa_version_supported)
1719

20+
logger = init_logger(__name__)
21+
1822

1923
class FlashAttentionBackend(AttentionBackend):
2024

@@ -143,6 +147,11 @@ def __init__(
143147
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
144148
self.fa_version = VLLM_FLASH_ATTN_VERSION
145149

150+
if not is_fa_version_supported(self.fa_version):
151+
logger.error("Cannot use FA version %d is not supported due to %s",
152+
self.fa_version,
153+
fa_version_unsupported_reason(self.fa_version))
154+
146155
assert is_fa_version_supported(self.fa_version)
147156

148157
def forward(
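Taken together, users who force a version via VLLM_FLASH_ATTN_VERSION now get a logged explanation rather than a bare assertion. A quick way to probe support up front, using only functions shown in this commit:

```python
# Quick support probe using the helpers referenced in this commit.
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                  is_fa_version_supported)

for version in (2, 3):
    if is_fa_version_supported(version):
        print(f"FA{version}: supported")
    else:
        print(f"FA{version}: unsupported "
              f"({fa_version_unsupported_reason(version)})")
```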
