@@ -210,22 +210,19 @@ def __init__(
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
-        assert self.num_heads % self.num_kv_heads == 0
-        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(head_size,
                                         dtype,
                                         kv_cache_dtype=None,
                                         block_size=16,
                                         is_attention_free=False)
         backend = backend_name_to_enum(attn_backend.get_name())
+        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            backend = _Backend.XFORMERS
 
         self.attn_backend = backend if backend in {
             _Backend.TORCH_SDPA,
             _Backend.XFORMERS,
-            _Backend.FLASH_ATTN,
-            _Backend.FLASH_ATTN_VLLM_V1,
         } else _Backend.TORCH_SDPA
 
     def forward(
@@ -235,45 +232,15 @@ def forward(
         value: torch.Tensor,
     ) -> torch.Tensor:
         """Input shape: batch_size x seq_len x hidden_size"""
+        # TODO(Isotr0py): Use existing backend implementations and support FA3
         bsz, q_len, _ = query.size()
         kv_len = key.size(1)
 
         query = query.view(bsz, q_len, self.num_heads, self.head_size)
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
 
-        if (num_repeat := self.num_queries_per_kv) > 1:
-            # Handle MQA and GQA
-            key = torch.repeat_interleave(key, num_repeat, dim=2)
-            value = torch.repeat_interleave(value, num_repeat, dim=2)
-
-        if self.attn_backend in {
-                _Backend.FLASH_ATTN,
-                _Backend.FLASH_ATTN_VLLM_V1,
-        }:
-            from vllm.vllm_flash_attn import flash_attn_varlen_func
-
-            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
-                                        step=q_len,
-                                        dtype=torch.int32,
-                                        device=query.device)
-            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
-                                        step=kv_len,
-                                        dtype=torch.int32,
-                                        device=key.device)
-
-            out = flash_attn_varlen_func(
-                query.flatten(0, 1),
-                key.flatten(0, 1),
-                value.flatten(0, 1),
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=q_len,
-                max_seqlen_k=kv_len,
-                softmax_scale=self.scale,
-            )
-            out = out.reshape(bsz, q_len, -1)
-        elif self.attn_backend == _Backend.XFORMERS:
+        if self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
             out = xops.memory_efficient_attention_forward(query,
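
For context, a minimal standalone sketch of the xformers path that forward() now falls through to. The tensor layout mirrors the diff; the toy sizes, device, and dtype are assumptions for illustration, and the local `scale` stands in for `self.scale`.

import torch
from xformers import ops as xops

# Toy shapes (assumed): batch of 2, 16 query/key tokens, 8 heads of size 64.
bsz, q_len, kv_len, num_heads, head_size = 2, 16, 16, 8, 64
scale = head_size**-0.5

# Same layout the patched forward() produces after its .view() calls:
# (batch_size, seq_len, num_heads, head_size).
query = torch.randn(bsz, q_len, num_heads, head_size,
                    device="cuda", dtype=torch.float16)
key = torch.randn(bsz, kv_len, num_heads, head_size,
                  device="cuda", dtype=torch.float16)
value = torch.randn(bsz, kv_len, num_heads, head_size,
                    device="cuda", dtype=torch.float16)

# Computes softmax(q @ k^T * scale) @ v without materializing the full
# attention matrix; `scale` plays the role of self.scale in the diff.
out = xops.memory_efficient_attention_forward(query, key, value, scale=scale)

# Back to (batch_size, seq_len, hidden_size), the shape the layer returns.
out = out.reshape(bsz, q_len, -1)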