Skip to content

Commit 3adf0ff

Browse files
[Platform] Do not raise error if _Backend is not found (#12023)
Signed-off-by: wangxiyuan <[email protected]> Signed-off-by: Mengqing Cao <[email protected]> Co-authored-by: Mengqing Cao <[email protected]>
1 parent ad388d2 commit 3adf0ff

File tree

6 files changed

+49
-16
lines changed

6 files changed

+49
-16
lines changed

tests/kernels/test_attention_selector.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,12 @@ def test_flash_attn(monkeypatch):
9494

9595

9696
def test_invalid_env(monkeypatch):
97-
"""Throw an exception if the backend name is invalid."""
97+
"""Ignore the invalid env variable if it is set."""
9898
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
99-
with pytest.raises(ValueError):
100-
get_attn_backend(16, torch.float16, None, 16, False)
99+
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
100+
backend = get_attn_backend(32, torch.float16, None, 16, False)
101+
assert backend.get_name() == "FLASH_ATTN"
102+
103+
# when block size == 16, backend will fall back to XFORMERS
104+
backend = get_attn_backend(16, torch.float16, None, 16, False)
105+
assert backend.get_name() == "XFORMERS"
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from vllm.attention.backends.flash_attn import FlashAttentionBackend
2+
3+
4+
class DummyAttentionBackend(FlashAttentionBackend):
5+
6+
@staticmethod
7+
def get_name() -> str:
8+
return "Dummy_Backend"

tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,7 @@
33

44
class DummyPlatform(CudaPlatform):
55
device_name = "DummyDevice"
6+
7+
def get_attn_backend_cls(self, backend_name, head_size, dtype,
8+
kv_cache_dtype, block_size, use_v1):
9+
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501

tests/plugins_tests/test_platform_plugins.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
import torch
2+
3+
from tests.kernels.utils import override_backend_env_variable
4+
from vllm.attention.selector import get_attn_backend
5+
from vllm.utils import STR_INVALID_VAL
6+
7+
18
def test_platform_plugins():
29
# simulate workload by running an example
310
import runpy
@@ -14,3 +21,10 @@ def test_platform_plugins():
1421
f"Expected DummyDevice, got {current_platform.device_name}, "
1522
"possibly because current_platform is imported before the plugin"
1623
f" is loaded. The first import:\n{_init_trace}")
24+
25+
26+
def test_oot_attention_backend(monkeypatch):
27+
# ignore the backend env variable if it is set
28+
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
29+
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
30+
assert backend.get_name() == "Dummy_Backend"

vllm/attention/layer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,11 @@ def __init__(
190190
kv_cache_dtype=None,
191191
block_size=16,
192192
is_attention_free=False)
193-
attn_backend = backend_name_to_enum(attn_backend.get_name())
194-
if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
195-
attn_backend = _Backend.XFORMERS
193+
backend = backend_name_to_enum(attn_backend.get_name())
194+
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
195+
backend = _Backend.XFORMERS
196196

197-
self.attn_backend = attn_backend if attn_backend in {
197+
self.attn_backend = backend if backend in {
198198
_Backend.TORCH_SDPA, _Backend.XFORMERS
199199
} else _Backend.TORCH_SDPA
200200

vllm/attention/selector.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,18 @@
1414
logger = init_logger(__name__)
1515

1616

17-
def backend_name_to_enum(backend_name: str) -> _Backend:
18-
assert backend_name is not None
19-
20-
backend_members = _Backend.__members__
21-
if backend_name not in backend_members:
22-
raise ValueError(f"Invalid attention backend '{backend_name}'. "
23-
f"Available backends: {', '.join(backend_members)} "
24-
"(case-sensitive).")
17+
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
18+
"""
19+
Convert a string backend name to a _Backend enum value.
2520
26-
return _Backend[backend_name]
21+
Returns:
22+
* _Backend: enum value if backend_name is a valid in-tree type
23+
* None: otherwise it's an invalid in-tree type or an out-of-tree platform is
24+
loaded.
25+
"""
26+
assert backend_name is not None
27+
return _Backend[backend_name] if backend_name in _Backend.__members__ else \
28+
None
2729

2830

2931
def get_env_variable_attn_backend() -> Optional[_Backend]:

0 commit comments

Comments
 (0)