Skip to content

Commit fa9ee08

Browse files
authored
[Misc] Set default backend to SDPA for get_vit_attn_backend (#12235)
Signed-off-by: wangxiyuan <[email protected]>
1 parent 347eeeb commit fa9ee08

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

vllm/model_executor/models/vision.py

Lines changed: 16 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -82,23 +82,25 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
8282
if backend_by_env_var is not None:
8383
selected_backend = backend_name_to_enum(backend_by_env_var)
8484
if selected_backend is None:
85-
# For Volta and Turing GPUs, use xformers instead.
86-
device_available = current_platform.has_device_capability(80)
87-
if device_available and support_fa:
88-
from transformers.utils import is_flash_attn_2_available
89-
if is_flash_attn_2_available():
90-
selected_backend = _Backend.FLASH_ATTN
85+
if current_platform.is_cuda():
86+
device_available = current_platform.has_device_capability(80)
87+
if device_available and support_fa:
88+
from transformers.utils import is_flash_attn_2_available
89+
if is_flash_attn_2_available():
90+
selected_backend = _Backend.FLASH_ATTN
91+
else:
92+
logger.warning_once(
93+
"Current `vllm-flash-attn` has a bug inside vision "
94+
"module, so we use xformers backend instead. You can "
95+
"run `pip install flash-attn` to use flash-attention "
96+
"backend.")
97+
selected_backend = _Backend.XFORMERS
9198
else:
92-
logger.warning_once(
93-
"Current `vllm-flash-attn` has a bug inside vision module, "
94-
"so we use xformers backend instead. You can run "
95-
"`pip install flash-attn` to use flash-attention backend.")
99+
# For Volta and Turing GPUs, use xformers instead.
96100
selected_backend = _Backend.XFORMERS
97-
elif current_platform.is_cpu() or current_platform.is_rocm():
98-
# ROCM doesn't support xformers
99-
selected_backend = _Backend.TORCH_SDPA
100101
else:
101-
selected_backend = _Backend.XFORMERS
102+
# Default to torch SDPA for other non-GPU platforms.
103+
selected_backend = _Backend.TORCH_SDPA
102104
return selected_backend
103105

104106

0 commit comments

Comments (0)