
Commit bec24a2

ywang96 authored and rasmith committed
1 parent e22bfc1 commit bec24a2

File tree

1 file changed: +4, -37 lines


vllm/attention/layer.py

Lines changed: 4 additions & 37 deletions
@@ -210,22 +210,19 @@ def __init__(
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
-        assert self.num_heads % self.num_kv_heads == 0
-        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(head_size,
                                         dtype,
                                         kv_cache_dtype=None,
                                         block_size=16,
                                         is_attention_free=False)
         backend = backend_name_to_enum(attn_backend.get_name())
+        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            backend = _Backend.XFORMERS
 
         self.attn_backend = backend if backend in {
             _Backend.TORCH_SDPA,
             _Backend.XFORMERS,
-            _Backend.FLASH_ATTN,
-            _Backend.FLASH_ATTN_VLLM_V1,
         } else _Backend.TORCH_SDPA
 
     def forward(
@@ -235,45 +232,15 @@ def forward(
         value: torch.Tensor,
     ) -> torch.Tensor:
         """Input shape: batch_size x seq_len x hidden_size"""
+        # TODO(Isotr0py): Use existing backend implementations and support FA3
         bsz, q_len, _ = query.size()
         kv_len = key.size(1)
 
         query = query.view(bsz, q_len, self.num_heads, self.head_size)
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
 
-        if (num_repeat := self.num_queries_per_kv) > 1:
-            # Handle MQA and GQA
-            key = torch.repeat_interleave(key, num_repeat, dim=2)
-            value = torch.repeat_interleave(value, num_repeat, dim=2)
-
-        if self.attn_backend in {
-                _Backend.FLASH_ATTN,
-                _Backend.FLASH_ATTN_VLLM_V1,
-        }:
-            from vllm.vllm_flash_attn import flash_attn_varlen_func
-
-            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
-                                        step=q_len,
-                                        dtype=torch.int32,
-                                        device=query.device)
-            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
-                                        step=kv_len,
-                                        dtype=torch.int32,
-                                        device=key.device)
-
-            out = flash_attn_varlen_func(
-                query.flatten(0, 1),
-                key.flatten(0, 1),
-                value.flatten(0, 1),
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=q_len,
-                max_seqlen_k=kv_len,
-                softmax_scale=self.scale,
-            )
-            out = out.reshape(bsz, q_len, -1)
-        elif self.attn_backend == _Backend.XFORMERS:
+        if self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
             out = xops.memory_efficient_attention_forward(query,
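For reference, here is a minimal standalone sketch of what the retained XFORMERS path computes. It assumes num_heads == num_kv_heads, a CUDA device, and xformers installed; the tensor shapes and the memory_efficient_attention_forward call mirror the diff above, while the concrete sizes and the explicit scale keyword are illustrative only, not taken from the commit.

import torch
from xformers import ops as xops

# Illustrative sizes; in the layer these come from the model config.
bsz, q_len, kv_len = 2, 16, 16
num_heads, head_size = 8, 64
scale = head_size ** -0.5

# Input shape: batch_size x seq_len x hidden_size (per the docstring above).
query = torch.randn(bsz, q_len, num_heads * head_size, device="cuda", dtype=torch.float16)
key = torch.randn(bsz, kv_len, num_heads * head_size, device="cuda", dtype=torch.float16)
value = torch.randn(bsz, kv_len, num_heads * head_size, device="cuda", dtype=torch.float16)

# Split the hidden dimension into (num_heads, head_size), as forward() does.
query = query.view(bsz, q_len, num_heads, head_size)
key = key.view(bsz, kv_len, num_heads, head_size)
value = value.view(bsz, kv_len, num_heads, head_size)

# xformers expects (batch, seq_len, num_heads, head_size) inputs.
out = xops.memory_efficient_attention_forward(query, key, value, scale=scale)
out = out.reshape(bsz, q_len, -1)  # back to batch_size x seq_len x hidden_size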
