2 files changed: +10 -2 lines changed
First changed file (the attention-selector test):

@@ -26,7 +26,7 @@ def clear_cache():
 @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
 def test_mha_attn_platform(device: str):
     """
-    Test that the attention selector between different platform and device.
+    Test the attention selector between different platform and device.
     """
     torch.set_default_dtype(torch.float16)
 
@@ -41,7 +41,7 @@ def test_mha_attn_platform(device: str):
     else:
         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.FLASH_ATTN
+            assert attn.attn_backend == _Backend.XFORMERS
 
         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
             attn = MultiHeadAttention(16, 72, scale=1)
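The test exercises the CUDA selection path on any host by patching the module-level current_platform object with unittest.mock.patch. As an aside, a minimal self-contained sketch of that pattern follows; the selector namespace and pick_backend function are hypothetical stand-ins, not vllm's actual API.

import types
from unittest.mock import patch

# Hypothetical stand-ins for vllm.attention.selector.current_platform and the
# backend-selection logic the test exercises; this is not vllm code.
selector = types.SimpleNamespace(current_platform="cpu")

def pick_backend() -> str:
    # Toy rule: pretend a CUDA platform selects the xFormers backend,
    # anything else falls back to SDPA.
    return "XFORMERS" if selector.current_platform == "cuda" else "TORCH_SDPA"

def test_pick_backend_platform():
    # Temporarily swap the platform attribute, mirroring how the real test
    # patches vllm.attention.selector.current_platform with CudaPlatform().
    with patch.object(selector, "current_platform", "cuda"):
        assert pick_backend() == "XFORMERS"
    # Outside the patch, the original platform value is restored.
    assert pick_backend() == "TORCH_SDPA"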
Second changed file (the MultiHeadAttention implementation's __init__ and forward):

@@ -210,6 +210,9 @@ def __init__(
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(head_size,
                                         dtype,

@@ -240,6 +243,11 @@ def forward(
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
 
+        if (num_repeat := self.num_queries_per_kv) > 1:
+            # Handle MQA and GQA
+            key = torch.repeat_interleave(key, num_repeat, dim=2)
+            value = torch.repeat_interleave(value, num_repeat, dim=2)
+
         if self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
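To make the new GQA/MQA handling concrete, here is a standalone sketch with illustrative sizes (num_heads=16 and num_kv_heads=4 are assumptions, not values from the diff). It shows how torch.repeat_interleave expands the key/value head dimension so every query head gets a matching KV head, which is what the added forward() lines do when num_queries_per_kv > 1.

import torch

# Illustrative sizes; the patch only requires num_heads % num_kv_heads == 0.
bsz, kv_len, head_size = 2, 8, 64
num_heads, num_kv_heads = 16, 4
num_queries_per_kv = num_heads // num_kv_heads  # 4 query heads per KV head

key = torch.randn(bsz, kv_len, num_kv_heads, head_size)
value = torch.randn(bsz, kv_len, num_kv_heads, head_size)

# Repeat each KV head along the head dimension (dim=2), as in the patch.
if num_queries_per_kv > 1:
    key = torch.repeat_interleave(key, num_queries_per_kv, dim=2)
    value = torch.repeat_interleave(value, num_queries_per_kv, dim=2)

# Both tensors now carry num_heads heads, matching the query tensor layout.
assert key.shape == (bsz, kv_len, num_heads, head_size)
assert value.shape == (bsz, kv_len, num_heads, head_size)

Repeating the KV heads copies memory, but it lets every downstream backend receive tensors with equal query and key/value head counts.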