
Commit e20c92b

[Kernel] Move attn_type to Attention.__init__() (#11690)
Signed-off-by: Chen Zhang <[email protected]>
1 parent 32c9eff commit e20c92b

18 files changed: 159 additions, 201 deletions
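
The change in one sentence: the attention type now lives on the Attention layer itself rather than being passed on every forward() call. Below is a minimal, hedged sketch of the migration (placeholder names only; the constructor keyword and the trimmed forward() signature follow the diffs below):

from vllm.attention import Attention, AttentionType


def build_encoder_attention(num_heads: int, head_size: int) -> Attention:
    # Illustrative helper, not part of the commit. Construction normally
    # happens under set_current_vllm_config(...), as in the test below.
    scale = float(1.0 / (head_size**0.5))

    # Before: attn = Attention(num_heads, head_size, scale=scale)
    #         attn.forward(q, k, v, kv_cache, attn_metadata,
    #                      attn_type=AttentionType.ENCODER)
    # After: the type is fixed once, at construction time ...
    attn = Attention(num_heads, head_size, scale=scale,
                     attn_type=AttentionType.ENCODER)
    # ... and forward() drops the per-call attn_type argument:
    #         attn.forward(q, k, v, kv_cache, attn_metadata)
    return attn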

tests/kernels/test_encoder_decoder_attn.py

Lines changed: 49 additions & 51 deletions

@@ -13,8 +13,7 @@
 import torch
 
 from tests.kernels.utils import *
-from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
-                            AttentionType)
+from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
 from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
                                      global_force_attn_backend_context_manager)
@@ -64,6 +63,7 @@ class TestPoint(NamedTuple):
     max_dec_seq_len: int
     max_enc_seq_len: int
     num_blocks: int
+    attn_type: AttentionType
 
 
 class TestResources(NamedTuple):
@@ -96,7 +96,6 @@ class TestResources(NamedTuple):
     '''
 
     scale: float
-    attn_backend: AttentionBackend
     attn: Attention
     kv_cache: torch.Tensor
 
@@ -129,16 +128,17 @@ class that Attention will automatically select when it is constructed.
     '''
 
     scale = float(1.0 / (test_pt.head_size**0.5))
-    attn_backend = make_backend(test_pt.backend_name)
     attn = Attention(
         test_pt.num_heads,
         test_pt.head_size,
         scale=scale,
+        prefix=f"{test_pt.attn_type}",
+        attn_type=test_pt.attn_type,
     )
     if test_pt.num_blocks is None or test_pt.num_heads is None:
         # Caller does not require a KV cache
         return TestResources(
-            scale, attn_backend, attn,
+            scale, attn,
             torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))
 
     # Construct KV cache
@@ -148,7 +148,7 @@ class that Attention will automatically select when it is constructed.
         test_pt.block_size,
         device=CUDA_DEVICE,
         backend=test_pt.backend_name)
-    return TestResources(scale, attn_backend, attn, kv_cache)
+    return TestResources(scale, attn, kv_cache)
 
 
 def _encoder_attn_setup(
@@ -193,6 +193,7 @@ def _encoder_attn_setup(
         _,
         max_q_seq_len,
         _,
+        _,
     ) = test_pt
 
     scale = test_rsrcs.scale
@@ -301,6 +302,7 @@ def _decoder_attn_setup(
         max_q_seq_len,
         _,
         _,
+        _,
     ) = test_pt
 
     scale = test_rsrcs.scale
@@ -488,6 +490,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
         max_decoder_seq_len,
         max_encoder_seq_len,
         _,
+        _,
     ) = test_pt
 
     scale = test_rsrcs.scale
@@ -622,7 +625,6 @@ def _run_encoder_attention_test(
       & attn_metadata
     '''
    assert attn_metadata.num_decode_tokens == 0
-    attn_type = AttentionType.ENCODER
     packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
     assert packed_qkv is not None
     with set_forward_context(attn_metadata, vllm_config):
@@ -635,14 +637,11 @@ def _run_encoder_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(reshaped_query,
-                            packed_qkv.key,
-                            packed_qkv.value,
-                            torch.tensor([],
-                                         dtype=torch.float32,
-                                         device=packed_qkv.query.device),
-                            attn_metadata,
-                            attn_type=attn_type)
+        return attn.forward(
+            reshaped_query, packed_qkv.key, packed_qkv.value,
+            torch.tensor([],
+                         dtype=torch.float32,
+                         device=packed_qkv.query.device), attn_metadata)
 
 
 def _run_decoder_self_attention_test(
@@ -675,7 +674,6 @@ def _run_decoder_self_attention_test(
     * Attention.forward() applied to packed_{query,key,value}, kv_cache
       & attn_metadata
     '''
-    attn_type = AttentionType.DECODER
     attn = test_rsrcs.attn
     kv_cache = test_rsrcs.kv_cache
     packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
@@ -690,12 +688,8 @@ def _run_decoder_self_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(reshaped_query,
-                            packed_qkv.key,
-                            packed_qkv.value,
-                            kv_cache,
-                            attn_metadata,
-                            attn_type=attn_type)
+        return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
+                            kv_cache, attn_metadata)
 
 
 def _run_encoder_decoder_cross_attention_test(
@@ -742,7 +736,6 @@ def _run_encoder_decoder_cross_attention_test(
     '''
     assert decoder_test_params.packed_qkvo.packed_qkv is not None
 
-    attn_type = AttentionType.ENCODER_DECODER
     attn = test_rsrcs.attn
     kv_cache = test_rsrcs.kv_cache
     if cross_test_params is None:
@@ -762,12 +755,8 @@ def _run_encoder_decoder_cross_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(reshaped_query,
-                            key,
-                            value,
-                            kv_cache,
-                            attn_metadata,
-                            attn_type=attn_type)
+        return attn.forward(reshaped_query, key, value, kv_cache,
+                            attn_metadata)
 
 
 @pytest.fixture(autouse=True)
@@ -839,7 +828,7 @@ def test_encoder_only(
     # is not part of this test
     test_pt = TestPoint(num_heads, head_size, attn_backend.name,
                         batch_size, block_size, max_dec_seq_len,
-                        max_enc_seq_len, 4096)
+                        max_enc_seq_len, 4096, AttentionType.ENCODER)
 
     # Attention scale factor, attention backend instance, attention wrapper
     # instance, KV cache init
@@ -855,7 +844,7 @@ def test_encoder_only(
     # Shared prefill metadata structure
 
     prephase_attn_metadata: AttentionMetadata = make_test_metadata(
-        test_rsrcs.attn_backend,
+        attn_backend,
         True,
         None,
         decoder_test_params=None,
@@ -961,20 +950,29 @@ def test_e2e_enc_dec_attn(
     # Note: KV cache size of 4096 is arbitrary & chosen intentionally
     # to be more than necessary, since exceeding the kv cache size
     # is not part of this test
-    test_pt = TestPoint(num_heads, head_size, attn_backend.name,
-                        batch_size, block_size, max_dec_seq_len,
-                        max_enc_seq_len, 4096)
+    enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+                            batch_size, block_size, max_dec_seq_len,
+                            max_enc_seq_len, 4096, AttentionType.ENCODER)
+    enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+                                batch_size, block_size, max_dec_seq_len,
+                                max_enc_seq_len, 4096,
+                                AttentionType.ENCODER_DECODER)
+    dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+                            batch_size, block_size, max_dec_seq_len,
+                            max_enc_seq_len, 4096, AttentionType.DECODER)
 
     # Attention scale factor, attention backend instance, attention wrapper
     # instance, KV cache init
     vllm_config = VllmConfig()
     with set_current_vllm_config(vllm_config):
-        test_rsrcs = _make_test_resources(test_pt)
+        enc_test_rsrcs = _make_test_resources(enc_test_pt)
+        enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt)
+        dec_test_rsrcs = _make_test_resources(dec_test_pt)
 
     # Construct encoder attention test params (only used
     # during prefill)
 
-    enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
+    enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs)
 
     # Construct Decoder self-attention prefill-phase & decode-phase
     # test params, including query/key/value tensors, decoder self-attention
@@ -987,7 +985,7 @@ def test_e2e_enc_dec_attn(
         prephase_dec_test_params,
         decphase_dec_test_params,
         cross_block_base_addr,
-    ) = _decoder_attn_setup(test_pt, test_rsrcs)
+    ) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs)
 
     # Construct encoder/decoder cross-attention prefill-phase
     # & decode-phase test params, including key/value tensors,
@@ -1000,14 +998,14 @@ def test_e2e_enc_dec_attn(
         dec_qkv,
         enc_test_params,
         prephase_dec_test_params,
-        test_pt,
-        test_rsrcs,
+        enc_dec_test_pt,
+        enc_dec_test_rsrcs,
         block_base_addr=cross_block_base_addr)
 
     # Shared prefill metadata structure
     assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
     prephase_attn_metadata: AttentionMetadata = make_test_metadata(
-        test_rsrcs.attn_backend,
+        attn_backend,
         True,
         prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
        decoder_test_params=prephase_dec_test_params,
@@ -1017,10 +1015,10 @@ def test_e2e_enc_dec_attn(
 
     # PREFILL: encoder attention
 
-    enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
+    enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn,
                                                    enc_test_params,
                                                    prephase_attn_metadata,
-                                                   test_pt=test_pt,
+                                                   test_pt=enc_test_pt,
                                                    vllm_config=vllm_config)
 
     # - Is encoder attention result correct?
@@ -1030,10 +1028,10 @@ def test_e2e_enc_dec_attn(
     # PREFILL: decoder self-attention test
 
     prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
-        test_rsrcs,
+        dec_test_rsrcs,
         prephase_dec_test_params,
         prephase_attn_metadata,
-        test_pt=test_pt,
+        test_pt=dec_test_pt,
         vllm_config=vllm_config)
 
     # - Is prefill decoder self-attention correct?
@@ -1044,11 +1042,11 @@ def test_e2e_enc_dec_attn(
     # PREFILL: encoder/decoder cross-attention test
 
     prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
-        test_rsrcs,
+        enc_dec_test_rsrcs,
         prephase_dec_test_params,
         prephase_cross_test_params,
         prephase_attn_metadata,
-        test_pt=test_pt,
+        test_pt=enc_dec_test_pt,
         vllm_config=vllm_config)
 
     # - Is prefill encoder/decoder cross-attention correct?
@@ -1059,7 +1057,7 @@ def test_e2e_enc_dec_attn(
     # DECODE: build decode-phase attention metadata
 
     decphase_attn_metadata: AttentionMetadata = make_test_metadata(
-        test_rsrcs.attn_backend,
+        attn_backend,
         False,
         dec_qkv.q_seq_lens,
         decoder_test_params=decphase_dec_test_params,
@@ -1070,10 +1068,10 @@ def test_e2e_enc_dec_attn(
     # DECODE: decoder self-attention test
 
     decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
-        test_rsrcs,
+        dec_test_rsrcs,
         decphase_dec_test_params,
         decphase_attn_metadata,
-        test_pt=test_pt,
+        test_pt=dec_test_pt,
         vllm_config=vllm_config)
 
     # - Is decode-phase decoder self-attention correct?
@@ -1084,11 +1082,11 @@ def test_e2e_enc_dec_attn(
     # DECODE: encoder/decoder cross-attention test
 
     decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
-        test_rsrcs,
+        enc_dec_test_rsrcs,
         decphase_dec_test_params,
         None,
         decphase_attn_metadata,
-        test_pt=test_pt,
+        test_pt=enc_dec_test_pt,
         vllm_config=vllm_config)
 
     # - Is decode-phase encoder/decoder cross-attention correct?
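
Since each Attention instance now carries its attention type, test_e2e_enc_dec_attn builds a separate TestPoint and TestResources per type instead of sharing one. A condensed sketch of that pattern (it assumes the names defined in this test file are in scope; the dict comprehension is illustrative, not code from the test):

from vllm.attention import AttentionType

# Assumes TestPoint, _make_test_resources, set_current_vllm_config and the
# test's local variables (num_heads, head_size, attn_backend, batch_size,
# block_size, max_dec_seq_len, max_enc_seq_len, vllm_config) are in scope.
test_pts = {
    attn_type: TestPoint(num_heads, head_size, attn_backend.name,
                         batch_size, block_size, max_dec_seq_len,
                         max_enc_seq_len, 4096, attn_type)
    for attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_DECODER,
                      AttentionType.DECODER)
}
with set_current_vllm_config(vllm_config):
    test_rsrcs = {t: _make_test_resources(pt) for t, pt in test_pts.items()}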

tests/kernels/utils.py

Lines changed: 7 additions & 5 deletions

@@ -13,6 +13,7 @@
 
 from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.platforms.interface import _Backend
 from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
                         STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
 
@@ -790,7 +791,7 @@ def make_block_tables_slot_mapping(
 
 
 def make_test_metadata(
-    attn_backend: AttentionBackend,
+    attn_backend: _Backend,
     is_prompt: bool,
     seq_lens: Optional[List[int]],
     decoder_test_params: Optional[PhaseTestParameters],
@@ -815,7 +816,7 @@ def make_test_metadata(
 
     Arguments:
 
-    * attn_backend: Backend for sourcing attention kernels
+    * attn_backend_name: Backend for sourcing attention kernels
     * is_prompt: prefill if True, o/w decode
     * seq_lens: list of token counts for each sequence
     * decoder_test_params: decoder self-attention test params;
@@ -882,6 +883,8 @@ def make_test_metadata(
         # (kv_mmap)
         cross_kv_mmap = cross_test_params.kv_mmap
 
+    attn_backend_obj = make_backend(attn_backend.name)
+
     if is_prompt:
         # Prefill-phase scenario
 
@@ -902,8 +905,7 @@ def make_test_metadata(
             context_lens,
             encoder_seq_lens,
             device=device)
-
-        return attn_backend.make_metadata(
+        return attn_backend_obj.make_metadata(
             num_prefills=num_prefills,
             slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
             multi_modal_placeholder_index_maps=None,
@@ -952,7 +954,7 @@ def make_test_metadata(
             encoder_seq_lens,
             device=device)
 
-        return attn_backend.make_metadata(
+        return attn_backend_obj.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=kv_mmap.slot_mapping,
            multi_modal_placeholder_index_maps=None,
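
make_test_metadata now receives the _Backend selection itself and builds the concrete backend object internally, so the test resources no longer need to carry an attn_backend field. A minimal restatement of that conversion step (the helper name below is illustrative, not from the commit):

from tests.kernels.utils import make_backend
from vllm.platforms.interface import _Backend


def resolve_backend(attn_backend: _Backend):
    # Mirrors the new line inside make_test_metadata: the enum member's name
    # (e.g. "XFORMERS") is what make_backend() expects.
    return make_backend(attn_backend.name)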

vllm/attention/backends/abstract.py

Lines changed: 1 addition & 1 deletion

@@ -233,6 +233,7 @@ def __init__(
         kv_cache_dtype: str = "auto",
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
+        attn_type: str = AttentionType.DECODER,
     ) -> None:
         raise NotImplementedError
 
@@ -246,7 +247,6 @@ def forward(
         attn_metadata: T,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
-        attn_type: str = AttentionType.DECODER,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
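
With attn_type part of the constructor contract, a backend implementation can validate and store the attention type once instead of inspecting a per-call argument. A hypothetical sketch (class name and validation logic are illustrative, not from vLLM):

from vllm.attention import AttentionType


class SketchAttentionImpl:
    """Illustrative only; real implementations subclass the abstract
    interface above and take many more constructor arguments."""

    def __init__(self, attn_type: str = AttentionType.DECODER) -> None:
        supported = (AttentionType.DECODER, AttentionType.ENCODER,
                     AttentionType.ENCODER_DECODER)
        if attn_type not in supported:
            raise NotImplementedError(f"Unsupported attn_type: {attn_type}")
        # Stored once; forward() now selects behavior from self.attn_type.
        self.attn_type = attn_type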
