Commit 6a5a4e5

[Misc] Pass attention to impl backend
Signed-off-by: wangxiyuan <[email protected]>
1 parent 5c89a29 commit 6a5a4e5

12 files changed: +59 -76 lines changed

vllm/attention/backends/abstract.py

Lines changed: 1 addition & 2 deletions
@@ -244,13 +244,12 @@ def __init__(
     @abstractmethod
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: T,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
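
For context, the abstract interface above now hands the backend impl the attention layer itself, and the impl reads the KV-cache scales from that layer instead of taking k_scale/v_scale arguments. Below is a minimal, hypothetical sketch of that contract, not code from this commit; ToyAttentionImpl and its plain scaled-dot-product body are illustrative assumptions.

# Hypothetical sketch of the new AttentionImpl.forward contract; not vLLM code.
from typing import Optional

import torch


class ToyAttentionImpl:

    def forward(
        self,
        layer: torch.nn.Module,   # the attention layer that owns the scales
        query: torch.Tensor,      # [num_heads, seq_len, head_size]
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,   # ignored in this toy
        attn_metadata: object,    # backend-specific in vLLM; ignored here
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Scales are read from the layer instead of being passed as arguments;
        # a backend without FP8 KV-cache support can still assert on them,
        # mirroring the FlashAttention check in this commit.
        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
        if output is not None:
            output.copy_(out)
            return output
        return out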

vllm/attention/backends/blocksparse_attn.py

Lines changed: 5 additions & 6 deletions
@@ -358,13 +358,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: BlocksparseFlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -401,8 +400,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         if prefill_meta := attn_metadata.prefill_metadata:
@@ -439,8 +438,8 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
                 tp_rank=self.tp_rank,
                 blocksparse_local_blocks=self.local_blocks,
                 blocksparse_vert_stride=self.vert_stride,

vllm/attention/backends/flash_attn.py

Lines changed: 4 additions & 5 deletions
@@ -634,13 +634,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -657,7 +656,7 @@ def forward(
         NOTE: It in-place updates the output tensor.
         """
         # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
-        assert k_scale == 1.0 and v_scale == 1.0, (
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
         assert output is not None, "Output tensor must be provided."
@@ -709,8 +708,8 @@ def forward(
                 kv_cache[1],
                 updated_slot_mapping.flatten(),  # type: ignore[union-attr]
                 kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         (num_prefill_query_tokens, num_prefill_kv_tokens,

vllm/attention/backends/flashinfer.py

Lines changed: 7 additions & 8 deletions
@@ -792,13 +792,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
@@ -826,8 +825,8 @@ def forward(
                 kv_cache[:, 1],
                 attn_metadata.slot_mapping.flatten(),
                 kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
             # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
             # to process the cache when the kv_cache_dtype is fp8
@@ -886,8 +885,8 @@ def forward(
                 kv_cache,
                 logits_soft_cap=logits_soft_cap,
                 causal=True,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
                 window_left=window_left)
         if decode_meta := attn_metadata.decode_metadata:
             assert decode_meta is not None
@@ -897,8 +896,8 @@ def forward(
                 kv_cache,
                 sm_scale=softmax_scale,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
                 window_left=window_left)
 
         if prefill_output is None and decode_output is not None:

vllm/attention/backends/hpu_attn.py

Lines changed: 1 addition & 2 deletions
@@ -152,13 +152,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: HPUAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.

vllm/attention/backends/ipex_attn.py

Lines changed: 8 additions & 9 deletions
@@ -171,13 +171,12 @@ def split_kv_cache(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: IpexAttnMetadata,  # type: ignore
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.
@@ -193,7 +192,7 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -210,8 +209,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         if attn_metadata.is_prompt:
@@ -296,8 +295,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
             else:
                 # Run PagedAttention V2.
@@ -329,8 +328,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         # Reshape the output tensor.

vllm/attention/backends/pallas.py

Lines changed: 2 additions & 3 deletions
@@ -150,13 +150,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.
@@ -173,7 +172,7 @@ def forward(
         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         batch_size, seq_len, hidden_size = query.shape
         query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
         key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 9 additions & 10 deletions
@@ -414,13 +414,12 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: ROCmFlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -458,8 +457,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
@@ -567,8 +566,8 @@ def forward(
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
                     self.sliding_window[0],
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -613,8 +612,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
             else:
                 output[num_prefill_tokens:] = PagedAttention.forward_decode(
@@ -628,8 +627,8 @@ def forward(
                     self.num_kv_heads,
                     self.scale,
                     self.alibi_slopes,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         # Reshape the output tensor.

vllm/attention/backends/torch_sdpa.py

Lines changed: 7 additions & 10 deletions
@@ -429,13 +429,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: TorchSDPAMetadata,  # type: ignore
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -451,7 +450,7 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         attn_type = self.attn_type
         if (attn_type == AttentionType.ENCODER
                 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
@@ -493,11 +492,9 @@ def forward(
             # Update self-attention KV cache (prefill/decode)
             updated_slot_mapping = attn_metadata.slot_mapping
 
-            PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                value_cache,
-                                                updated_slot_mapping,
-                                                self.kv_cache_dtype,
-                                                k_scale, v_scale)
+            PagedAttention.write_to_paged_cache(
+                key, value, key_cache, value_cache, updated_slot_mapping,
+                self.kv_cache_dtype, layer._k_scale, layer._v_scale)
 
         if attn_type != AttentionType.ENCODER:
             # Decoder self-attention supports chunked prefill.
@@ -571,8 +568,8 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         # Reshape the output tensor.

vllm/attention/backends/xformers.py

Lines changed: 8 additions & 11 deletions
@@ -412,13 +412,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: Optional[torch.Tensor],
         value: Optional[torch.Tensor],
         kv_cache: torch.Tensor,
         attn_metadata: "XFormersMetadata",
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
@@ -524,11 +523,9 @@ def forward(
                 # If kv_cache is not provided, the new key and value tensors are
                 # not cached. This happens during the initial memory
                 # profiling run.
-                PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                    value_cache,
-                                                    updated_slot_mapping,
-                                                    self.kv_cache_dtype,
-                                                    k_scale, v_scale)
+                PagedAttention.write_to_paged_cache(
+                    key, value, key_cache, value_cache, updated_slot_mapping,
+                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)
         (num_prefill_query_tokens, num_prefill_kv_tokens,
          num_decode_query_tokens) = \
             get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
@@ -580,8 +577,8 @@ def forward(
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
                     self.sliding_window,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
                 assert output[:num_prefill_query_tokens].shape == out.shape
                 output[:num_prefill_query_tokens] = out
@@ -607,8 +604,8 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         # Reshape the output tensor.
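
The call-site updates live in the remaining files of this commit, which are not shown above. The sketch below is therefore an assumption about what that side of the change looks like, not this commit's code: ToyAttention, _StubImpl, and the default scale values are invented for illustration, and the scales are presumably populated during quantization or weight loading in vLLM.

# Hypothetical caller-side sketch; not code from this commit.
import torch


class _StubImpl:
    """Stand-in for a backend impl that follows the new forward signature."""

    def forward(self, layer, query, key, value, kv_cache, attn_metadata,
                output=None):
        del kv_cache, attn_metadata, output  # unused in this stub
        # The impl reads the scales from the layer it was handed.
        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value)


class ToyAttention(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        # Plain defaults here; in vLLM these would be set elsewhere.
        self._k_scale = 1.0
        self._v_scale = 1.0
        self.impl = _StubImpl()

    def forward(self, query, key, value, kv_cache, attn_metadata):
        # Instead of forwarding k_scale/v_scale, the layer passes itself.
        return self.impl.forward(self, query, key, value, kv_cache,
                                 attn_metadata)


layer = ToyAttention()
q = k = v = torch.randn(8, 16, 64)
out = layer(q, k, v, kv_cache=torch.empty(0), attn_metadata=None)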
