
Commit cfbb8c9

Authored by NickLucche
[TPU][V1] MHA Pallas backend (#15288)
Signed-off-by: NickLucche <[email protected]>
1 parent baec0d4 commit cfbb8c9

File tree

tests/v1/tpu/test_mha_attn.py
vllm/attention/layer.py

2 files changed: +117, -2 lines

tests/v1/tpu/test_mha_attn.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
"""
Test:

* Tests for MultiHeadAttention layer
"""

import pytest
import torch
import torch_xla
import torch_xla.core
import torch_xla.core.xla_model

from vllm import envs
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform

if not envs.VLLM_USE_V1:
    pytest.skip(
        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
        allow_module_level=True,
    )


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear lru cache to ensure each test case runs without caching.
    """
    _cached_get_attn_backend.cache_clear()


def ref_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    Native implementation of scaled dot product attention without mask:
    - query, key, value: [batch_size, seq_len, num_heads, head_size]
    - attn_mask: [batch_size, seq_len, seq_len]
    """
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.matmul(attn_weights, value).transpose(1, 2)
    return out


BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("device", [torch_xla.core.xla_model.xla_device()])
def test_mha_attn_forward(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    device: str,
):
    current_platform.seed_everything(0)
    # These are expected to be f32
    q = torch.randn(batch_size, seq_len, num_heads * head_size, device=device)
    k = torch.randn(batch_size,
                    seq_len,
                    num_kv_heads * head_size,
                    device=device)
    v = torch.randn(batch_size,
                    seq_len,
                    num_kv_heads * head_size,
                    device=device)
    scale = 1.0 / head_size**0.5
    attn = MultiHeadAttention(num_heads,
                              head_size,
                              scale=scale,
                              num_kv_heads=num_kv_heads)
    output = attn(q, k, v)

    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads

    q = q.reshape(batch_size, seq_len, num_heads, head_size)
    k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
    v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
    if num_queries_per_kv > 1:
        k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
        v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

    ref_output = ref_attention(
        q,
        k,
        v,
        scale=scale,
    ).reshape(batch_size, seq_len, num_heads * head_size)
    # torch_xla flash_attn kernel is less accurate but much faster
    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-3)
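
As a side note (not part of this commit), ref_attention is the plain scaled-dot-product formulation, so on recent PyTorch (2.1+, which exposes the scale argument) it can be sanity-checked on CPU against torch.nn.functional.scaled_dot_product_attention. The snippet below is an illustrative sketch and assumes ref_attention from the test above is in scope:

import torch
import torch.nn.functional as F

# Illustrative sanity check only; assumes ref_attention (defined in the test
# above) has been imported or copied into scope.
batch_size, seq_len, num_heads, head_size = 2, 4, 8, 64
q, k, v = (torch.randn(batch_size, seq_len, num_heads, head_size)
           for _ in range(3))
scale = head_size**-0.5

# SDPA expects [batch, heads, seq, head_dim], so transpose in and back out.
expected = F.scaled_dot_product_attention(q.transpose(1, 2),
                                          k.transpose(1, 2),
                                          v.transpose(1, 2),
                                          scale=scale).transpose(1, 2)
torch.testing.assert_close(ref_attention(q, k, v, scale=scale),
                           expected,
                           atol=1e-5,
                           rtol=1e-5)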

vllm/attention/layer.py

Lines changed: 8 additions & 2 deletions
@@ -281,8 +281,7 @@ def __init__(
             backend = _Backend.XFORMERS
 
         self.attn_backend = backend if backend in {
-            _Backend.TORCH_SDPA,
-            _Backend.XFORMERS,
+            _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
         } else _Backend.TORCH_SDPA
 
     def forward(
@@ -320,6 +319,13 @@ def forward(
                                               value,
                                               scale=self.scale)
             out = out.transpose(1, 2)
+        elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
+            query, key, value = (x.transpose(1, 2)
+                                 for x in (query, key, value))
+            from torch_xla.experimental.custom_kernel import flash_attention
+            out = flash_attention(query, key, value, sm_scale=self.scale)
+            out = out.transpose(1, 2)
+
         return out.reshape(bsz, q_len, -1)
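
The new branch transposes query/key/value from [batch, seq_len, num_heads, head_size] to [batch, num_heads, seq_len, head_size] and dispatches to torch_xla's experimental Pallas flash_attention kernel. A minimal standalone sketch of that call (illustrative shapes; assumes a TPU host with torch_xla installed):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

# Illustrative shapes; the layer passes tensors already transposed to
# [batch, num_heads, seq_len, head_size] as in the diff above.
device = xm.xla_device()
bsz, num_heads, seq_len, head_size = 1, 16, 128, 64
q, k, v = (torch.randn(bsz, num_heads, seq_len, head_size, device=device)
           for _ in range(3))

# sm_scale mirrors the layer's softmax scale, 1 / sqrt(head_size).
out = flash_attention(q, k, v, sm_scale=head_size**-0.5)
# out: [bsz, num_heads, seq_len, head_size]; forward() transposes it back
# before the final reshape to [bsz, seq_len, num_heads * head_size].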