
Commit f777816

wangxiyuan authored and mzusman committed
[Misc] Pass attention to impl backend (vllm-project#12218)
Signed-off-by: wangxiyuan <[email protected]>
1 parent df01517 commit f777816

12 files changed: +86 additions, -78 deletions

vllm/attention/backends/abstract.py

Lines changed: 19 additions & 4 deletions
@@ -1,8 +1,8 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from dataclasses import dataclass, fields
-from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
-                    Tuple, Type, TypeVar)
+from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
+                    Protocol, Set, Tuple, Type, TypeVar)

 import torch

@@ -223,6 +223,22 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         raise NotImplementedError


+class AttentionLayer(Protocol):
+
+    _k_scale: float
+    _v_scale: float
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        ...
+
+
 class AttentionImpl(ABC, Generic[T]):

     @abstractmethod
@@ -244,13 +260,12 @@ def __init__(
     @abstractmethod
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: T,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
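
The change in abstract.py is the core of this commit: AttentionImpl.forward now receives the attention layer itself, typed against the new AttentionLayer protocol, instead of loose k_scale/v_scale floats, and reads the scales from layer._k_scale / layer._v_scale. A minimal caller-side sketch, assuming a layer object that exposes those attributes (the corresponding vllm.attention.layer changes are among the 12 files but not shown in this excerpt; MyAttention below is hypothetical):

import torch

from vllm.attention.backends.abstract import AttentionMetadata


class MyAttention(torch.nn.Module):
    """Hypothetical layer that structurally satisfies AttentionLayer."""

    def __init__(self, impl) -> None:
        super().__init__()
        # Per-layer KV-cache scales; impls read these via layer._k_scale /
        # layer._v_scale instead of receiving them as float arguments.
        self._k_scale = 1.0
        self._v_scale = 1.0
        self.impl = impl  # any concrete AttentionImpl

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # The layer passes itself as the first argument to the impl.
        return self.impl.forward(self, query, key, value, kv_cache,
                                 attn_metadata)

One likely benefit of passing the layer rather than individual floats is that additional per-layer state can reach backend implementations later without touching every forward signature again.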

vllm/attention/backends/blocksparse_attn.py

Lines changed: 6 additions & 6 deletions
@@ -4,6 +4,7 @@
 import torch

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (CommonAttentionState,
                                            CommonMetadataBuilder)
@@ -358,13 +359,12 @@ def __init__(

     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: BlocksparseFlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -401,8 +401,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )

         if prefill_meta := attn_metadata.prefill_metadata:
@@ -439,8 +439,8 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
                 tp_rank=self.tp_rank,
                 blocksparse_local_blocks=self.local_blocks,
                 blocksparse_vert_stride=self.vert_stride,

vllm/attention/backends/flash_attn.py

Lines changed: 5 additions & 5 deletions
@@ -8,6 +8,7 @@

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionType)
@@ -634,13 +635,12 @@ def __init__(

     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -657,7 +657,7 @@ def forward(
         NOTE: It in-place updates the output tensor.
         """
         # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
-        assert k_scale == 1.0 and v_scale == 1.0, (
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")

         assert output is not None, "Output tensor must be provided."
@@ -709,8 +709,8 @@ def forward(
                 kv_cache[1],
                 updated_slot_mapping.flatten(),  # type: ignore[union-attr]
                 kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )

         (num_prefill_query_tokens, num_prefill_kv_tokens,

vllm/attention/backends/flashinfer.py

Lines changed: 8 additions & 8 deletions
@@ -23,6 +23,7 @@
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionState, AttentionType)
@@ -792,13 +793,12 @@ def __init__(

     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

@@ -826,8 +826,8 @@ def forward(
                 kv_cache[:, 1],
                 attn_metadata.slot_mapping.flatten(),
                 kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
             # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
             # to process the cache when the kv_cache_dtype is fp8
@@ -886,8 +886,8 @@ def forward(
                 kv_cache,
                 logits_soft_cap=logits_soft_cap,
                 causal=True,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
                 window_left=window_left)
         if decode_meta := attn_metadata.decode_metadata:
             assert decode_meta is not None
@@ -897,8 +897,8 @@ def forward(
                 kv_cache,
                 sm_scale=softmax_scale,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
                 window_left=window_left)

         if prefill_output is None and decode_output is not None:

vllm/attention/backends/hpu_attn.py

Lines changed: 2 additions & 2 deletions
@@ -11,6 +11,7 @@
 from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
@@ -152,13 +153,12 @@ def __init__(

     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: HPUAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.

vllm/attention/backends/ipex_attn.py

Lines changed: 9 additions & 9 deletions
@@ -7,6 +7,7 @@

 from vllm._ipex_ops import ipex_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import (PagedAttention,
@@ -171,13 +172,12 @@ def split_kv_cache(

     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: IpexAttnMetadata,  # type: ignore
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.
@@ -193,7 +193,7 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -210,8 +210,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )

         if attn_metadata.is_prompt:
@@ -296,8 +296,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
             else:
                 # Run PagedAttention V2.
@@ -329,8 +329,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )

         # Reshape the output tensor.

vllm/attention/backends/pallas.py

Lines changed: 3 additions & 3 deletions
@@ -5,6 +5,7 @@
 import torch_xla.experimental.custom_kernel  # Required to register custom ops.

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState

@@ -150,13 +151,12 @@ def __init__(

     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Tuple[torch.Tensor, torch.Tensor],
         attn_metadata: PallasMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.
@@ -173,7 +173,7 @@ def forward(
         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         batch_size, seq_len, hidden_size = query.shape
         query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
         key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 10 additions & 10 deletions
@@ -7,6 +7,7 @@
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (CommonAttentionState,
                                            CommonMetadataBuilder)
@@ -414,13 +415,12 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:

     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: ROCmFlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -458,8 +458,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )

         num_prefill_tokens = attn_metadata.num_prefill_tokens
@@ -567,8 +567,8 @@ def forward(
                         prefill_meta.max_query_len,
                         self.alibi_slopes,
                         self.sliding_window[0],
-                        k_scale,
-                        v_scale,
+                        layer._k_scale,
+                        layer._v_scale,
                     )

         if decode_meta := attn_metadata.decode_metadata:
@@ -613,8 +613,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
             else:
                 output[num_prefill_tokens:] = PagedAttention.forward_decode(
@@ -628,8 +628,8 @@ def forward(
                     self.num_kv_heads,
                     self.scale,
                     self.alibi_slopes,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )

         # Reshape the output tensor.
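
Every backend in this diff applies the same mechanical edit: add layer: AttentionLayer as the first forward parameter, drop the k_scale/v_scale arguments, and read layer._k_scale / layer._v_scale wherever the scales were used. A minimal sketch of what a custom (out-of-tree) impl looks like after this change; MyBackendImpl is hypothetical, and only the signature and scale access mirror the commit:

from typing import Optional

import torch

from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
                                              AttentionMetadata)


class MyBackendImpl(AttentionImpl):
    """Hypothetical impl showing the post-commit forward signature."""

    # __init__ omitted for brevity; AttentionImpl also declares an abstract
    # __init__ that concrete backends must implement.

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Scales now come from the layer object rather than keyword arguments.
        k_scale, v_scale = layer._k_scale, layer._v_scale
        # ...backend-specific kernels would consume k_scale / v_scale here.
        raise NotImplementedError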
