@@ -210,22 +210,18 @@ def __init__(
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
-        assert self.num_heads % self.num_kv_heads == 0
-        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(head_size,
                                         dtype,
                                         kv_cache_dtype=None,
                                         block_size=16,
                                         is_attention_free=False)
         backend = backend_name_to_enum(attn_backend.get_name())
+        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            backend = _Backend.XFORMERS
 
         self.attn_backend = backend if backend in {
-            _Backend.TORCH_SDPA,
-            _Backend.XFORMERS,
-            _Backend.FLASH_ATTN,
-            _Backend.FLASH_ATTN_VLLM_V1,
+            _Backend.TORCH_SDPA, _Backend.XFORMERS
         } else _Backend.TORCH_SDPA
 
     def forward(
@@ -235,26 +231,15 @@ def forward(
         value: torch.Tensor,
     ) -> torch.Tensor:
         """Input shape: batch_size x seq_len x hidden_size"""
+        # TODO(Isotr0py): Use existing backend implementations and support FA2
         bsz, q_len, _ = query.size()
         kv_len = key.size(1)
 
         query = query.view(bsz, q_len, self.num_heads, self.head_size)
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
 
-        if (num_repeat := self.num_queries_per_kv) > 1:
-            # Handle MQA and GQA
-            key = torch.repeat_interleave(key, num_repeat, dim=2)
-            value = torch.repeat_interleave(value, num_repeat, dim=2)
-
-        if self.attn_backend in {
-                _Backend.FLASH_ATTN,
-                _Backend.FLASH_ATTN_VLLM_V1,
-        }:
-            from vllm.vllm_flash_attn import flash_attn_func
-
-            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
-        elif self.attn_backend == _Backend.XFORMERS:
+        if self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
             out = xops.memory_efficient_attention_forward(query,
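After this change, the layer only ever dispatches to xformers or torch SDPA. For context, here is a minimal, self-contained sketch of that dispatch; the standalone mha_forward function and its use_xformers flag are illustrative stand-ins (the real code branches on self.attn_backend), and it assumes num_heads == num_kv_heads, since the repeat_interleave GQA/MQA expansion is dropped in this hunk.

# Illustrative sketch only; not the vLLM source. Assumes num_heads ==
# num_kv_heads (no GQA/MQA expansion) and xformers installed for that branch.
import torch
import torch.nn.functional as F


def mha_forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                scale: float, use_xformers: bool) -> torch.Tensor:
    """query/key/value: [batch_size, seq_len, num_heads, head_size]."""
    if use_xformers:
        from xformers import ops as xops

        # xformers consumes and returns the [B, M, H, K] layout directly.
        out = xops.memory_efficient_attention_forward(query, key, value,
                                                      scale=scale)
    else:
        # torch SDPA expects [B, H, M, K], so transpose heads in and out.
        # (The scale= kwarg needs PyTorch >= 2.1.)
        q, k, v = (x.transpose(1, 2) for x in (query, key, value))
        out = F.scaled_dot_product_attention(q, k, v, scale=scale)
        out = out.transpose(1, 2)
    # Merge the heads back into hidden_size: [B, M, H * K].
    return out.reshape(out.shape[0], out.shape[1], -1)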