Commit f1fc051

[Misc] Add FA2 support to ViT MHA layer (#12355)
Signed-off-by: Isotr0py <[email protected]>
1 parent bf21481 commit f1fc051

File tree

2 files changed: +146 -5 lines changed

tests/kernels/test_mha_attn.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
"""
Test:

* Tests for MultiHeadAttention layer
"""
from unittest.mock import patch

import pytest
import torch

from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _Backend, _cached_get_attn_backend
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear lru cache to ensure each test case runs without caching.
    """
    _cached_get_attn_backend.cache_clear()


@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_mha_attn_platform(device: str):
    """
    Test the attention backend selection on different platforms and devices.
    """
    torch.set_default_dtype(torch.float16)

    if device == "cpu":
        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
            attn = MultiHeadAttention(16, 64, scale=1)
            assert attn.attn_backend == _Backend.TORCH_SDPA
    elif device == "hip":
        with patch("vllm.attention.selector.current_platform", RocmPlatform()):
            attn = MultiHeadAttention(16, 64, scale=1)
            assert attn.attn_backend == _Backend.TORCH_SDPA
    else:
        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
            attn = MultiHeadAttention(16, 64, scale=1)
            assert attn.attn_backend == _Backend.FLASH_ATTN

        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
            attn = MultiHeadAttention(16, 72, scale=1)
            assert attn.attn_backend == _Backend.XFORMERS


def ref_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    Native implementation of scaled dot product attention without mask:
    - query, key, value: [batch_size, seq_len, num_heads, head_size]
    - attn_mask: [batch_size, seq_len, seq_len]
    """
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.matmul(attn_weights, value).transpose(1, 2)
    return out


BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [
    torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
CUDA_DEVICES = ["cuda"]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    current_platform.seed_everything(0)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    q = torch.randn(batch_size, seq_len, num_heads * head_size)
    k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
    v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
    scale = 1.0 / head_size**0.5
    attn = MultiHeadAttention(num_heads,
                              head_size,
                              scale=scale,
                              num_kv_heads=num_kv_heads)
    output = attn(q, k, v)

    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads
    q = q.reshape(batch_size, seq_len, num_heads, head_size)
    k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
    v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
    if num_queries_per_kv > 1:
        k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
        v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

    ref_output = ref_attention(
        q,
        k,
        v,
        scale=scale,
    ).reshape(batch_size, seq_len, num_heads * head_size)
    torch.testing.assert_close(output, ref_output)
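
As a side note, the ref_attention helper above is plain unfused scaled dot-product attention, so it can be cross-checked against PyTorch's built-in kernel. The standalone sketch below is an illustration, not part of the commit; it restates the helper for self-containment and assumes PyTorch 2.1 or newer for the explicit scale= keyword.

# Illustrative cross-check (not part of the commit): the reference path
# should agree with torch.nn.functional.scaled_dot_product_attention.
import torch
import torch.nn.functional as F


def ref_attention(query, key, value, scale):
    # [batch, seq, heads, head_size] -> [batch, heads, seq, head_size]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    return torch.matmul(attn_weights, value).transpose(1, 2)


batch, seq, heads, head_size = 2, 5, 4, 64
q, k, v = (torch.randn(batch, seq, heads, head_size) for _ in range(3))
scale = head_size**-0.5

ref = ref_attention(q, k, v, scale)
# SDPA expects [batch, heads, seq, head_size]; the scale= kwarg needs torch >= 2.1.
sdpa = F.scaled_dot_product_attention(q.transpose(1, 2),
                                      k.transpose(1, 2),
                                      v.transpose(1, 2),
                                      scale=scale).transpose(1, 2)
torch.testing.assert_close(ref, sdpa, atol=1e-3, rtol=1e-3)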

vllm/attention/layer.py

Lines changed: 20 additions & 5 deletions
@@ -210,18 +210,22 @@ def __init__(
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(head_size,
                                         dtype,
                                         kv_cache_dtype=None,
                                         block_size=16,
                                         is_attention_free=False)
         backend = backend_name_to_enum(attn_backend.get_name())
-        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            backend = _Backend.XFORMERS
 
         self.attn_backend = backend if backend in {
-            _Backend.TORCH_SDPA, _Backend.XFORMERS
+            _Backend.TORCH_SDPA,
+            _Backend.XFORMERS,
+            _Backend.FLASH_ATTN,
+            _Backend.FLASH_ATTN_VLLM_V1,
         } else _Backend.TORCH_SDPA
 
     def forward(
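
The hunk above widens the backend whitelist for the ViT MHA layer: rather than forcing flash attention back to xFormers, the resolved backend is kept when it is one of TORCH_SDPA, XFORMERS, FLASH_ATTN, or FLASH_ATTN_VLLM_V1, and anything else falls back to torch SDPA. A minimal sketch of that rule in isolation; pick_vit_mha_backend is a hypothetical helper for illustration, not vLLM API, and FLASHINFER is used only as an example of an unsupported backend. The second hunk below then wires the FA2 call into forward().

# Hypothetical helper illustrating the selection rule above; not vLLM API.
from vllm.attention.selector import _Backend

_SUPPORTED = {
    _Backend.TORCH_SDPA,
    _Backend.XFORMERS,
    _Backend.FLASH_ATTN,
    _Backend.FLASH_ATTN_VLLM_V1,
}


def pick_vit_mha_backend(resolved: _Backend) -> _Backend:
    # Keep the resolved backend if the ViT MHA layer supports it,
    # otherwise fall back to torch SDPA.
    return resolved if resolved in _SUPPORTED else _Backend.TORCH_SDPA


assert pick_vit_mha_backend(_Backend.FLASH_ATTN) == _Backend.FLASH_ATTN
assert pick_vit_mha_backend(_Backend.FLASHINFER) == _Backend.TORCH_SDPA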
@@ -231,15 +235,26 @@ def forward(
         value: torch.Tensor,
     ) -> torch.Tensor:
         """Input shape: batch_size x seq_len x hidden_size"""
-        # TODO(Isotr0py): Use existing backend implementations and support FA2
         bsz, q_len, _ = query.size()
         kv_len = key.size(1)
 
         query = query.view(bsz, q_len, self.num_heads, self.head_size)
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
 
-        if self.attn_backend == _Backend.XFORMERS:
+        if (num_repeat := self.num_queries_per_kv) > 1:
+            # Handle MQA and GQA
+            key = torch.repeat_interleave(key, num_repeat, dim=2)
+            value = torch.repeat_interleave(value, num_repeat, dim=2)
+
+        if self.attn_backend in {
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+        }:
+            from vllm.vllm_flash_attn import flash_attn_func
+
+            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
             out = xops.memory_efficient_attention_forward(query,
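
For context, a minimal usage sketch of the updated layer, mirroring the new test's setup. This is an illustration, not part of the diff; it assumes a CUDA build of vLLM whose attention selector resolves to FLASH_ATTN for head_size 64, and on other setups the layer simply falls back to xFormers or torch SDPA with the same input and output shapes.

# Usage sketch (illustration only): driving MultiHeadAttention the way a
# ViT-style encoder block would.
import torch

from vllm.attention.layer import MultiHeadAttention

torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)

batch_size, seq_len, num_heads, head_size = 2, 197, 12, 64
hidden_size = num_heads * head_size

attn = MultiHeadAttention(num_heads, head_size, scale=head_size**-0.5)
print(attn.attn_backend)  # e.g. _Backend.FLASH_ATTN on a supported GPU

# Inputs are [batch_size, seq_len, hidden_size], as documented in forward().
q = torch.randn(batch_size, seq_len, hidden_size)
k = torch.randn(batch_size, seq_len, hidden_size)
v = torch.randn(batch_size, seq_len, hidden_size)
out = attn(q, k, v)  # [batch_size, seq_len, hidden_size]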
