@@ -82,23 +82,25 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
         if backend_by_env_var is not None:
             selected_backend = backend_name_to_enum(backend_by_env_var)
     if selected_backend is None:
-        # For Volta and Turing GPUs, use xformers instead.
-        device_available = current_platform.has_device_capability(80)
-        if device_available and support_fa:
-            from transformers.utils import is_flash_attn_2_available
-            if is_flash_attn_2_available():
-                selected_backend = _Backend.FLASH_ATTN
+        if current_platform.is_cuda():
+            device_available = current_platform.has_device_capability(80)
+            if device_available and support_fa:
+                from transformers.utils import is_flash_attn_2_available
+                if is_flash_attn_2_available():
+                    selected_backend = _Backend.FLASH_ATTN
+                else:
+                    logger.warning_once(
+                        "Current `vllm-flash-attn` has a bug inside vision "
+                        "module, so we use xformers backend instead. You can "
+                        "run `pip install flash-attn` to use flash-attention "
+                        "backend.")
+                    selected_backend = _Backend.XFORMERS
             else:
-                logger.warning_once(
-                    "Current `vllm-flash-attn` has a bug inside vision module, "
-                    "so we use xformers backend instead. You can run "
-                    "`pip install flash-attn` to use flash-attention backend.")
+                # For Volta and Turing GPUs, use xformers instead.
                 selected_backend = _Backend.XFORMERS
-        elif current_platform.is_cpu() or current_platform.is_rocm():
-            # ROCM doesn't support xformers
-            selected_backend = _Backend.TORCH_SDPA
         else:
-            selected_backend = _Backend.XFORMERS
+            # Default to torch SDPA for other non-GPU platforms.
+            selected_backend = _Backend.TORCH_SDPA
     return selected_backend
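
For readers skimming the diff, here is a minimal, self-contained sketch of the selection order this change introduces. The Platform stub and the FLASH_ATTN_AVAILABLE flag below are illustrative stand-ins for vLLM's current_platform and the is_flash_attn_2_available() check, not the real objects:

# Sketch of the new vision-attention backend selection order (assumed
# simplification of the diff above; names below are illustrative stubs).
from dataclasses import dataclass
from enum import Enum, auto


class _Backend(Enum):
    FLASH_ATTN = auto()
    XFORMERS = auto()
    TORCH_SDPA = auto()


@dataclass
class Platform:
    cuda: bool = False
    device_capability: int = 0

    def is_cuda(self) -> bool:
        return self.cuda

    def has_device_capability(self, minimum: int) -> bool:
        return self.device_capability >= minimum


FLASH_ATTN_AVAILABLE = True  # pretend `flash-attn` is importable


def pick_vit_backend(platform: Platform, support_fa: bool = False) -> _Backend:
    if platform.is_cuda():
        if platform.has_device_capability(80) and support_fa:
            # Ampere or newer: prefer flash-attention when installed.
            if FLASH_ATTN_AVAILABLE:
                return _Backend.FLASH_ATTN
            # flash-attn missing: fall back to xformers.
            return _Backend.XFORMERS
        # Volta/Turing, or the model does not support FA: use xformers.
        return _Backend.XFORMERS
    # Non-CUDA platforms (CPU, ROCm, ...) default to torch SDPA.
    return _Backend.TORCH_SDPA


assert pick_vit_backend(Platform(cuda=True, device_capability=90), True) is _Backend.FLASH_ATTN
assert pick_vit_backend(Platform(cuda=True, device_capability=75), True) is _Backend.XFORMERS
assert pick_vit_backend(Platform(cuda=False)) is _Backend.TORCH_SDPA

In short: on CUDA, flash-attention is chosen only on compute capability 8.0+ with flash-attn installed; every other CUDA configuration falls back to xformers, and all non-CUDA platforms now default to torch SDPA instead of the previous CPU/ROCm-only special case.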