diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index a08c874407e..492acb91e8e 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -94,7 +94,12 @@ def test_flash_attn(monkeypatch):
 
 
 def test_invalid_env(monkeypatch):
-    """Throw an exception if the backend name is invalid."""
+    """Ignore the invalid env variable if it is set."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
-    with pytest.raises(ValueError):
-        get_attn_backend(16, torch.float16, None, 16, False)
+    with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        backend = get_attn_backend(32, torch.float16, None, 16, False)
+        assert backend.get_name() == "FLASH_ATTN"
+
+        # when head size == 16, backend will fall back to XFORMERS
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() == "XFORMERS"
diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py
new file mode 100644
index 00000000000..5634be3c8d8
--- /dev/null
+++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py
@@ -0,0 +1,8 @@
+from vllm.attention.backends.flash_attn import FlashAttentionBackend
+
+
+class DummyAttentionBackend(FlashAttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "Dummy_Backend"
diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
index fde93142f11..84721d5971c 100644
--- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
+++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
@@ -3,3 +3,7 @@
 
 class DummyPlatform(CudaPlatform):
     device_name = "DummyDevice"
+
+    def get_attn_backend_cls(self, backend_name, head_size, dtype,
+                             kv_cache_dtype, block_size, use_v1):
+        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
diff --git a/tests/plugins_tests/test_platform_plugins.py b/tests/plugins_tests/test_platform_plugins.py
index 69698b34c71..661aa5f649a 100644
--- a/tests/plugins_tests/test_platform_plugins.py
+++ b/tests/plugins_tests/test_platform_plugins.py
@@ -1,3 +1,10 @@
+import torch
+
+from tests.kernels.utils import override_backend_env_variable
+from vllm.attention.selector import get_attn_backend
+from vllm.utils import STR_INVALID_VAL
+
+
 def test_platform_plugins():
     # simulate workload by running an example
     import runpy
@@ -14,3 +21,10 @@
         f"Expected DummyDevice, got {current_platform.device_name}, "
         "possibly because current_platform is imported before the plugin"
         f" is loaded. The first import:\n{_init_trace}")
+
+
+def test_oot_attention_backend(monkeypatch):
+    # ignore the backend env variable if it is set
+    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
+    assert backend.get_name() == "Dummy_Backend"
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index a283e87d840..9b03fd73fe6 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -190,11 +190,11 @@ def __init__(
                                         kv_cache_dtype=None,
                                         block_size=16,
                                         is_attention_free=False)
-        attn_backend = backend_name_to_enum(attn_backend.get_name())
-        if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            attn_backend = _Backend.XFORMERS
+        backend = backend_name_to_enum(attn_backend.get_name())
+        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            backend = _Backend.XFORMERS
 
-        self.attn_backend = attn_backend if attn_backend in {
+        self.attn_backend = backend if backend in {
             _Backend.TORCH_SDPA, _Backend.XFORMERS
         } else _Backend.TORCH_SDPA
 
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 0ff007c87b1..81ea6eefb54 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -14,16 +14,18 @@
 logger = init_logger(__name__)
 
 
-def backend_name_to_enum(backend_name: str) -> _Backend:
-    assert backend_name is not None
-
-    backend_members = _Backend.__members__
-    if backend_name not in backend_members:
-        raise ValueError(f"Invalid attention backend '{backend_name}'. "
-                         f"Available backends: {', '.join(backend_members)} "
-                         "(case-sensitive).")
+def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
+    """
+    Convert a string backend name to a _Backend enum value.
 
-    return _Backend[backend_name]
+    Returns:
+    * _Backend: enum value if backend_name is a valid in-tree type
+    * None: otherwise it's an invalid in-tree type or an out-of-tree platform is
+      loaded.
+    """
+    assert backend_name is not None
+    return _Backend[backend_name] if backend_name in _Backend.__members__ else \
+        None
 
 
 def get_env_variable_attn_backend() -> Optional[_Backend]:
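For illustration, the dotted class path returned by get_attn_backend_cls above still has to be imported at runtime before the selector can use it. A minimal sketch of that resolution step, assuming the vllm_add_dummy_platform test plugin from this diff is installed; resolve_backend_cls here is a hypothetical helper, not vLLM's actual resolver:

    # Hypothetical helper sketching how a "pkg.module.ClassName" string
    # from get_attn_backend_cls can be turned into the backend class.
    import importlib


    def resolve_backend_cls(qualname: str) -> type:
        # Split the path into module and class name, import the module,
        # then look the class up as a module attribute.
        module_name, _, cls_name = qualname.rpartition(".")
        module = importlib.import_module(module_name)
        return getattr(module, cls_name)


    backend_cls = resolve_backend_cls(
        "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend")
    assert backend_cls.get_name() == "Dummy_Backend"

Because the platform hook hands back a string rather than a class object, the plugin module is only imported when its backend is actually selected, which keeps plugin discovery cheap.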