
Commit 1929795

mgoin and aurickq authored and committed
[V1] V1 FlashInfer Attention (vllm-project#16684)
Signed-off-by: mgoin <[email protected]>
Co-authored-by: Aurick Qiao <[email protected]>
1 parent 07e8588 commit 1929795

File tree: 7 files changed, +668 -13 lines changed


tests/v1/e2e/test_cascade_attention.py (+9 -1)
@@ -1,13 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import pytest
+
 from vllm import LLM, SamplingParams
 
+from ...utils import fork_new_process_for_each_test
+
 
-def test_cascade_attention(example_system_message, monkeypatch):
+@fork_new_process_for_each_test
+@pytest.mark.parametrize("attn_backend",
+                         ["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"])
+def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
     prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
         m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
 
         llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
         sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
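Note: the test is now parametrized over both V1 attention backends and forked into a fresh process per case, since the backend is fixed when the engine is constructed. Below is a minimal standalone sketch (not part of the commit; model name and prompt are taken from the test above) of what one parametrized case effectively does with the new FlashInfer V1 backend:

import os

# Select the V1 engine and the new FlashInfer V1 backend before building the engine.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_VLLM_V1"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
outputs = llm.generate(
    ["\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"],
    sampling_params)
print(outputs[0].outputs[0].text)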

vllm/engine/arg_utils.py (+10 -3)
@@ -1474,10 +1474,17 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 recommend_to_remove=False)
             return False
 
-        # No FlashInfer or XFormers so far.
+        # No XFormers so far.
         V1_BACKENDS = [
-            "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
-            "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
+            "FLASH_ATTN_VLLM_V1",
+            "FLASH_ATTN",
+            "PALLAS",
+            "PALLAS_VLLM_V1",
+            "TRITON_ATTN_VLLM_V1",
+            "TRITON_MLA",
+            "FLASHMLA",
+            "FLASHINFER",
+            "FLASHINFER_VLLM_V1",
         ]
        if (envs.is_set("VLLM_ATTENTION_BACKEND")
                and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
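Note: FLASHINFER and FLASHINFER_VLLM_V1 are added to the V1 backend whitelist, so setting VLLM_ATTENTION_BACKEND to either no longer forces a fallback off the V1 engine. A hedged sketch of the check this hunk extends (the function name and return handling below are illustrative, not the exact vLLM code):

import os
from typing import Optional

# Backends the V1 oracle accepts, as listed in the hunk above.
V1_BACKENDS = [
    "FLASH_ATTN_VLLM_V1",
    "FLASH_ATTN",
    "PALLAS",
    "PALLAS_VLLM_V1",
    "TRITON_ATTN_VLLM_V1",
    "TRITON_MLA",
    "FLASHMLA",
    "FLASHINFER",
    "FLASHINFER_VLLM_V1",
]

def backend_is_v1_compatible(name: Optional[str]) -> bool:
    # Unset means the platform picks its own default, treated as compatible here.
    return name is None or name in V1_BACKENDS

print(backend_is_v1_compatible(os.environ.get("VLLM_ATTENTION_BACKEND")))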

vllm/platforms/cuda.py (+3)
@@ -213,6 +213,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             return ("vllm.attention.backends."
                     "flashmla.FlashMLABackend")
         if use_v1:
+            if selected_backend == _Backend.FLASHINFER:
+                logger.info_once("Using FlashInfer backend on V1 engine.")
+                return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
             if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                 logger.info_once("Using Triton backend on V1 engine.")
                 return ("vllm.v1.attention.backends."

vllm/v1/attention/backends/flash_attn.py (+3 -4)
@@ -64,10 +64,6 @@ def get_kv_cache_shape(
             raise ValueError("Block size must be a multiple of 16.")
         return (2, num_blocks, block_size, num_kv_heads, head_size)
 
-    @staticmethod
-    def use_cascade_attention(*args, **kwargs) -> bool:
-        return use_cascade_attention(*args, **kwargs)
-
 
 @dataclass
 class FlashAttentionMetadata:
@@ -402,6 +398,9 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
         )
         return attn_metadata
 
+    def use_cascade_attention(self, *args, **kwargs) -> bool:
+        return use_cascade_attention(*args, **kwargs)
+
 
 class FlashAttentionImpl(AttentionImpl):
 
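Note: use_cascade_attention moves from a @staticmethod on the backend class to an instance method on the metadata builder (the class whose build path returns attn_metadata above), so the decision becomes per-builder and a backend such as FlashInfer can opt out without changing the backend class interface. A hedged sketch of the resulting shape (class names, the FlashInfer override, and the stubbed heuristic below are illustrative, not the commit's code):

# Stand-in for vLLM's module-level cascade-attention heuristic, stubbed so the
# sketch is self-contained and runnable.
def use_cascade_attention(*args, **kwargs) -> bool:
    return False

class FlashAttentionBuilderSketch:
    # After this commit: the builder delegates to the module-level heuristic.
    def use_cascade_attention(self, *args, **kwargs) -> bool:
        return use_cascade_attention(*args, **kwargs)

class FlashInferBuilderSketch:
    # Assumed for illustration: a builder that never uses cascade attention
    # can simply return False.
    def use_cascade_attention(self, *args, **kwargs) -> bool:
        return False

print(FlashAttentionBuilderSketch().use_cascade_attention())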
