 import torch

 from tests.kernels.utils import *
-from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
-                            AttentionType)
+from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
 from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
                                      global_force_attn_backend_context_manager)
@@ -64,6 +63,7 @@ class TestPoint(NamedTuple):
     max_dec_seq_len: int
     max_enc_seq_len: int
     num_blocks: int
+    attn_type: AttentionType


 class TestResources(NamedTuple):
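
Note: with this hunk, each TestPoint records which attention type it exercises, so the end-to-end test below can build a separate point (and separate TestResources) per role instead of sharing one. A condensed usage sketch based on the calls later in this diff; the argument values are stand-ins for the enclosing test's arguments:

    enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
                            batch_size, block_size, max_dec_seq_len,
                            max_enc_seq_len, 4096, AttentionType.ENCODER)
    # _make_test_resources() now reads the attention type from the point itself
    enc_test_rsrcs = _make_test_resources(enc_test_pt)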
@@ -96,7 +96,6 @@ class TestResources(NamedTuple):
     '''

     scale: float
-    attn_backend: AttentionBackend
     attn: Attention
     kv_cache: torch.Tensor

@@ -129,16 +128,17 @@ class that Attention will automatically select when it is constructed.
     '''

     scale = float(1.0 / (test_pt.head_size ** 0.5))
-    attn_backend = make_backend(test_pt.backend_name)
     attn = Attention(
         test_pt.num_heads,
         test_pt.head_size,
         scale=scale,
+        prefix=f"{test_pt.attn_type}",
+        attn_type=test_pt.attn_type,
     )
     if test_pt.num_blocks is None or test_pt.num_heads is None:
         # Caller does not require a KV cache
         return TestResources(
-            scale, attn_backend, attn,
+            scale, attn,
             torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

     # Construct KV cache
@@ -148,7 +148,7 @@ class that Attention will automatically select when it is constructed.
                              test_pt.block_size,
                              device=CUDA_DEVICE,
                              backend=test_pt.backend_name)
-    return TestResources(scale, attn_backend, attn, kv_cache)
+    return TestResources(scale, attn, kv_cache)


 def _encoder_attn_setup(
@@ -193,6 +193,7 @@ def _encoder_attn_setup(
         _,
         max_q_seq_len,
         _,
+        _,
     ) = test_pt

     scale = test_rsrcs.scale
@@ -301,6 +302,7 @@ def _decoder_attn_setup(
         max_q_seq_len,
         _,
         _,
+        _,
     ) = test_pt

     scale = test_rsrcs.scale
@@ -488,6 +490,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
         max_decoder_seq_len,
         max_encoder_seq_len,
         _,
+        _,
     ) = test_pt

     scale = test_rsrcs.scale
@@ -622,7 +625,6 @@ def _run_encoder_attention_test(
       & attn_metadata
     '''
     assert attn_metadata.num_decode_tokens == 0
-    attn_type = AttentionType.ENCODER
     packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
     assert packed_qkv is not None
     with set_forward_context(attn_metadata, vllm_config):
@@ -635,14 +637,11 @@ def _run_encoder_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(reshaped_query,
-                            packed_qkv.key,
-                            packed_qkv.value,
-                            torch.tensor([],
-                                         dtype=torch.float32,
-                                         device=packed_qkv.query.device),
-                            attn_metadata,
-                            attn_type=attn_type)
+        return attn.forward(
+            reshaped_query, packed_qkv.key, packed_qkv.value,
+            torch.tensor([],
+                         dtype=torch.float32,
+                         device=packed_qkv.query.device), attn_metadata)


 def _run_decoder_self_attention_test(
@@ -675,7 +674,6 @@ def _run_decoder_self_attention_test(
     * Attention.forward() applied to packed_{query,key,value}, kv_cache
       & attn_metadata
     '''
-    attn_type = AttentionType.DECODER
     attn = test_rsrcs.attn
     kv_cache = test_rsrcs.kv_cache
     packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
@@ -690,12 +688,8 @@ def _run_decoder_self_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(reshaped_query,
-                            packed_qkv.key,
-                            packed_qkv.value,
-                            kv_cache,
-                            attn_metadata,
-                            attn_type=attn_type)
+        return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
+                            kv_cache, attn_metadata)


 def _run_encoder_decoder_cross_attention_test(
@@ -742,7 +736,6 @@ def _run_encoder_decoder_cross_attention_test(
     '''
     assert decoder_test_params.packed_qkvo.packed_qkv is not None

-    attn_type = AttentionType.ENCODER_DECODER
     attn = test_rsrcs.attn
     kv_cache = test_rsrcs.kv_cache
     if cross_test_params is None:
@@ -762,12 +755,8 @@ def _run_encoder_decoder_cross_attention_test(
         # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
         reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
             -1, test_pt.num_heads * test_pt.head_size)
-        return attn.forward(reshaped_query,
-                            key,
-                            value,
-                            kv_cache,
-                            attn_metadata,
-                            attn_type=attn_type)
+        return attn.forward(reshaped_query, key, value, kv_cache,
+                            attn_metadata)


 @pytest.fixture(autouse=True)
@@ -839,7 +828,7 @@ def test_encoder_only(
         # is not part of this test
         test_pt = TestPoint(num_heads, head_size, attn_backend.name,
                             batch_size, block_size, max_dec_seq_len,
-                            max_enc_seq_len, 4096)
+                            max_enc_seq_len, 4096, AttentionType.ENCODER)

         # Attention scale factor, attention backend instance, attention wrapper
         # instance, KV cache init
@@ -855,7 +844,7 @@ def test_encoder_only(
         # Shared prefill metadata structure

         prephase_attn_metadata: AttentionMetadata = make_test_metadata(
-            test_rsrcs.attn_backend,
+            attn_backend,
             True,
             None,
             decoder_test_params=None,
@@ -961,20 +950,29 @@ def test_e2e_enc_dec_attn(
         # Note: KV cache size of 4096 is arbitrary & chosen intentionally
         # to be more than necessary, since exceeding the kv cache size
         # is not part of this test
-        test_pt = TestPoint(num_heads, head_size, attn_backend.name,
-                            batch_size, block_size, max_dec_seq_len,
-                            max_enc_seq_len, 4096)
+        enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+                                batch_size, block_size, max_dec_seq_len,
+                                max_enc_seq_len, 4096, AttentionType.ENCODER)
+        enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+                                    batch_size, block_size, max_dec_seq_len,
+                                    max_enc_seq_len, 4096,
+                                    AttentionType.ENCODER_DECODER)
+        dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+                                batch_size, block_size, max_dec_seq_len,
+                                max_enc_seq_len, 4096, AttentionType.DECODER)

         # Attention scale factor, attention backend instance, attention wrapper
         # instance, KV cache init
         vllm_config = VllmConfig()
         with set_current_vllm_config(vllm_config):
-            test_rsrcs = _make_test_resources(test_pt)
+            enc_test_rsrcs = _make_test_resources(enc_test_pt)
+            enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt)
+            dec_test_rsrcs = _make_test_resources(dec_test_pt)

         # Construct encoder attention test params (only used
         # during prefill)

-        enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
+        enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs)

         # Construct Decoder self-attention prefill-phase & decode-phase
         # test params, including query/key/value tensors, decoder self-attention
@@ -987,7 +985,7 @@ def test_e2e_enc_dec_attn(
             prephase_dec_test_params,
             decphase_dec_test_params,
             cross_block_base_addr,
-        ) = _decoder_attn_setup(test_pt, test_rsrcs)
+        ) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs)

         # Construct encoder/decoder cross-attention prefill-phase
         # & decode-phase test params, including key/value tensors,
@@ -1000,14 +998,14 @@ def test_e2e_enc_dec_attn(
             dec_qkv,
             enc_test_params,
             prephase_dec_test_params,
-            test_pt,
-            test_rsrcs,
+            enc_dec_test_pt,
+            enc_dec_test_rsrcs,
             block_base_addr=cross_block_base_addr)

         # Shared prefill metadata structure
         assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
         prephase_attn_metadata: AttentionMetadata = make_test_metadata(
-            test_rsrcs.attn_backend,
+            attn_backend,
             True,
             prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
             decoder_test_params=prephase_dec_test_params,
@@ -1017,10 +1015,10 @@ def test_e2e_enc_dec_attn(

         # PREFILL: encoder attention

-        enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
+        enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn,
                                                        enc_test_params,
                                                        prephase_attn_metadata,
-                                                       test_pt=test_pt,
+                                                       test_pt=enc_test_pt,
                                                        vllm_config=vllm_config)

         # - Is encoder attention result correct?
@@ -1030,10 +1028,10 @@ def test_e2e_enc_dec_attn(
         # PREFILL: decoder self-attention test

         prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
-            test_rsrcs,
+            dec_test_rsrcs,
             prephase_dec_test_params,
             prephase_attn_metadata,
-            test_pt=test_pt,
+            test_pt=dec_test_pt,
             vllm_config=vllm_config)

         # - Is prefill decoder self-attention correct?
@@ -1044,11 +1042,11 @@ def test_e2e_enc_dec_attn(
         # PREFILL: encoder/decoder cross-attention test

         prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
-            test_rsrcs,
+            enc_dec_test_rsrcs,
             prephase_dec_test_params,
             prephase_cross_test_params,
             prephase_attn_metadata,
-            test_pt=test_pt,
+            test_pt=enc_dec_test_pt,
             vllm_config=vllm_config)

         # - Is prefill encoder/decoder cross-attention correct?
@@ -1059,7 +1057,7 @@ def test_e2e_enc_dec_attn(
         # DECODE: build decode-phase attention metadata

         decphase_attn_metadata: AttentionMetadata = make_test_metadata(
-            test_rsrcs.attn_backend,
+            attn_backend,
             False,
             dec_qkv.q_seq_lens,
             decoder_test_params=decphase_dec_test_params,
@@ -1070,10 +1068,10 @@ def test_e2e_enc_dec_attn(
         # DECODE: decoder self-attention test

         decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
-            test_rsrcs,
+            dec_test_rsrcs,
             decphase_dec_test_params,
             decphase_attn_metadata,
-            test_pt=test_pt,
+            test_pt=dec_test_pt,
             vllm_config=vllm_config)

         # - Is decode-phase decoder self-attention correct?
@@ -1084,11 +1082,11 @@ def test_e2e_enc_dec_attn(
         # DECODE: encoder/decoder cross-attention test

         decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
-            test_rsrcs,
+            enc_dec_test_rsrcs,
             decphase_dec_test_params,
             None,
             decphase_attn_metadata,
-            test_pt=test_pt,
+            test_pt=enc_dec_test_pt,
             vllm_config=vllm_config)

         # - Is decode-phase encoder/decoder cross-attention correct?
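
Taken together, these hunks track a change in the Attention API: the attention type is now bound once, when the Attention module is constructed, rather than passed to every forward() call. A schematic before/after sketch of the pattern the test follows (variable names are stand-ins for the values used above; the prefix string is derived from the attention type, as in _make_test_resources):

    # Before: one shared Attention instance, type selected per forward() call
    attn = Attention(num_heads, head_size, scale=scale)
    out = attn.forward(q, k, v, kv_cache, attn_metadata,
                       attn_type=AttentionType.ENCODER)

    # After: type fixed at construction, so forward() no longer takes attn_type
    attn = Attention(num_heads, head_size, scale=scale,
                     prefix=f"{AttentionType.ENCODER}",
                     attn_type=AttentionType.ENCODER)
    out = attn.forward(q, k, v, kv_cache, attn_metadata)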