Commit 1be4877

Revert "[Misc] Add FA2 support to ViT MHA layer (#12355)"
This reverts commit f1fc051.
Parent: 324960a

File tree: 2 files changed (+5, -146 lines)


tests/kernels/test_mha_attn.py

Lines changed: 0 additions & 126 deletions (this file was deleted)

vllm/attention/layer.py

Lines changed: 5 additions & 20 deletions
@@ -210,22 +210,18 @@ def __init__(
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
-        assert self.num_heads % self.num_kv_heads == 0
-        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(head_size,
                                         dtype,
                                         kv_cache_dtype=None,
                                         block_size=16,
                                         is_attention_free=False)
         backend = backend_name_to_enum(attn_backend.get_name())
+        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            backend = _Backend.XFORMERS
 
         self.attn_backend = backend if backend in {
-            _Backend.TORCH_SDPA,
-            _Backend.XFORMERS,
-            _Backend.FLASH_ATTN,
-            _Backend.FLASH_ATTN_VLLM_V1,
+            _Backend.TORCH_SDPA, _Backend.XFORMERS
         } else _Backend.TORCH_SDPA
 
     def forward(
@@ -235,26 +231,15 @@ def forward(
         value: torch.Tensor,
     ) -> torch.Tensor:
         """Input shape: batch_size x seq_len x hidden_size"""
+        # TODO(Isotr0py): Use existing backend implementations and support FA2
         bsz, q_len, _ = query.size()
         kv_len = key.size(1)
 
         query = query.view(bsz, q_len, self.num_heads, self.head_size)
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
 
-        if (num_repeat := self.num_queries_per_kv) > 1:
-            # Handle MQA and GQA
-            key = torch.repeat_interleave(key, num_repeat, dim=2)
-            value = torch.repeat_interleave(value, num_repeat, dim=2)
-
-        if self.attn_backend in {
-                _Backend.FLASH_ATTN,
-                _Backend.FLASH_ATTN_VLLM_V1,
-        }:
-            from vllm.vllm_flash_attn import flash_attn_func
-
-            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
-        elif self.attn_backend == _Backend.XFORMERS:
+        if self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
 
             out = xops.memory_efficient_attention_forward(query,

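For context on the code being reverted: the removed block in forward() expanded the key/value heads so that MQA/GQA inputs matched the query head count before the FlashAttention call. Below is a minimal, hypothetical sketch of that expansion; the shapes are illustrative and not taken from the commit, and only the torch.repeat_interleave calls mirror the removed lines.

import torch

# Illustrative shapes (not from the commit): 2 sequences of 5 tokens,
# 8 query heads, 2 KV heads, head size 64.
bsz, seq_len, num_heads, num_kv_heads, head_size = 2, 5, 8, 2, 64

query = torch.randn(bsz, seq_len, num_heads, head_size)
key = torch.randn(bsz, seq_len, num_kv_heads, head_size)
value = torch.randn(bsz, seq_len, num_kv_heads, head_size)

# The reverted __init__ stored this ratio; here 8 // 2 == 4 queries per KV head.
num_queries_per_kv = num_heads // num_kv_heads

if num_queries_per_kv > 1:
    # Repeat each KV head along dim=2 so key/value match the query head count,
    # as the removed "Handle MQA and GQA" block did before calling flash_attn_func.
    key = torch.repeat_interleave(key, num_queries_per_kv, dim=2)
    value = torch.repeat_interleave(value, num_queries_per_kv, dim=2)

assert key.shape == value.shape == (bsz, seq_len, num_heads, head_size)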