Commit 2e0e017

[Platform] Add output for Attention Backend (#11981)
Signed-off-by: wangxiyuan <[email protected]>
1 parent 1f18adb commit 2e0e017

4 files changed: 9 additions, 5 deletions

vllm/attention/backends/abstract.py

Lines changed: 4 additions & 0 deletions
@@ -31,6 +31,10 @@ class AttentionType:
 
 class AttentionBackend(ABC):
     """Abstract class for attention backends."""
+    # For some attention backends, we allocate an output tensor before
+    # calling the custom op. When piecewise cudagraph is enabled, this
+    # makes sure the output tensor is allocated inside the cudagraph.
+    accept_output_buffer: bool = False
 
     @staticmethod
     @abstractmethod
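
For context on how this attribute is meant to be used (not part of the commit itself): a backend opts in simply by overriding the class attribute. The sketch below is a minimal, self-contained illustration; the simplified base class and `MyBufferedBackend` are hypothetical stand-ins, not vLLM's real interface.

# Minimal sketch of the opt-in pattern this commit introduces.
# Only `accept_output_buffer` comes from the diff above; the simplified
# base class and `MyBufferedBackend` are hypothetical illustrations.
from abc import ABC


class AttentionBackend(ABC):
    """Abstract class for attention backends (simplified)."""
    # Default: the backend allocates and returns its own output tensor.
    accept_output_buffer: bool = False


class MyBufferedBackend(AttentionBackend):
    # True promises that the backend's attention op can write its result
    # into a caller-provided output tensor, so the caller may allocate
    # that tensor inside a piecewise cudagraph.
    accept_output_buffer: bool = True


assert MyBufferedBackend.accept_output_buffer  # class-level flag, no instance needed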

vllm/attention/backends/flash_attn.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,8 @@
 
 class FlashAttentionBackend(AttentionBackend):
 
+    accept_output_buffer: bool = True
+
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]

vllm/attention/layer.py

Lines changed: 1 addition & 5 deletions
@@ -110,11 +110,7 @@ def __init__(
         self.use_direct_call = not current_platform.is_cuda_alike(
         ) and not current_platform.is_cpu()
 
-        # For some attention backends, we allocate an output tensor before
-        # calling the custom op. When piecewise cudagraph is enabled, this
-        # makes sure the output tensor is allocated inside the cudagraph.
-        self.use_output = self.backend == _Backend.FLASH_ATTN or \
-            self.backend == _Backend.FLASH_ATTN_VLLM_V1
+        self.use_output = attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
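
To make the caller-side effect concrete: when `use_output` is true, the layer pre-allocates the output tensor and hands it to the attention op, so the allocation lands inside the captured cudagraph. The snippet below is an illustrative sketch of that dispatch, assuming a generic `attn_op` callable with an `output=` keyword; it is not the actual vLLM forward path.

# Illustrative sketch only: the real layer dispatches to registered custom
# ops; `attn_op` and its `output=` convention are assumed here.
from typing import Callable, Optional

import torch


def run_attention(
    query: torch.Tensor,
    attn_op: Callable[..., Optional[torch.Tensor]],
    use_output: bool,
) -> torch.Tensor:
    if use_output:
        # Backend accepts an output buffer: allocate it at the call site,
        # inside the (piecewise) cudagraph region, and let the op fill it.
        output = torch.empty_like(query)
        attn_op(query, output=output)
        return output
    # Otherwise the op allocates and returns its own output tensor.
    return attn_op(query)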

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@
 
 class FlashAttentionBackend(AttentionBackend):
 
+    accept_output_buffer: bool = True
+
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
