Skip to content

Commit 72f4880

Browse files
authored
[Bugfix/CI] Fix broken kernels/test_mha.py (#12450)
1 parent aa2cd2c commit 72f4880

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

tests/kernels/test_mha_attn.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ def clear_cache():
2626
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
2727
def test_mha_attn_platform(device: str):
2828
"""
29-
Test that the attention selector between different platform and device.
29+
Test the attention selector between different platform and device.
3030
"""
3131
torch.set_default_dtype(torch.float16)
3232

@@ -41,7 +41,7 @@ def test_mha_attn_platform(device: str):
4141
else:
4242
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
4343
attn = MultiHeadAttention(16, 64, scale=1)
44-
assert attn.attn_backend == _Backend.FLASH_ATTN
44+
assert attn.attn_backend == _Backend.XFORMERS
4545

4646
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
4747
attn = MultiHeadAttention(16, 72, scale=1)

vllm/attention/layer.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -210,6 +210,9 @@ def __init__(
210210
self.scale = scale
211211
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
212212

213+
assert self.num_heads % self.num_kv_heads == 0
214+
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
215+
213216
dtype = torch.get_default_dtype()
214217
attn_backend = get_attn_backend(head_size,
215218
dtype,
@@ -240,6 +243,11 @@ def forward(
240243
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
241244
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
242245

246+
if (num_repeat := self.num_queries_per_kv) > 1:
247+
# Handle MQA and GQA
248+
key = torch.repeat_interleave(key, num_repeat, dim=2)
249+
value = torch.repeat_interleave(value, num_repeat, dim=2)
250+
243251
if self.attn_backend == _Backend.XFORMERS:
244252
from xformers import ops as xops
245253

0 commit comments

Comments (0)