File tree 1 file changed +22
-3
lines changed
1 file changed +22
-3
lines changed Original file line number Diff line number Diff line change @@ -251,9 +251,28 @@ def forward(
251
251
_Backend .FLASH_ATTN ,
252
252
_Backend .FLASH_ATTN_VLLM_V1 ,
253
253
}:
254
- from vllm .vllm_flash_attn import flash_attn_func
255
-
256
- out = flash_attn_func (query , key , value , softmax_scale = self .scale )
254
+ from vllm .vllm_flash_attn import flash_attn_varlen_func
255
+
256
+ cu_seqlens_q = torch .arange (0 , (bsz + 1 ) * q_len ,
257
+ step = q_len ,
258
+ dtype = torch .int32 ,
259
+ device = query .device )
260
+ cu_seqlens_k = torch .arange (0 , (bsz + 1 ) * kv_len ,
261
+ step = kv_len ,
262
+ dtype = torch .int32 ,
263
+ device = key .device )
264
+
265
+ out = flash_attn_varlen_func (
266
+ query .flatten (0 , 1 ),
267
+ key .flatten (0 , 1 ),
268
+ value .flatten (0 , 1 ),
269
+ cu_seqlens_q = cu_seqlens_q ,
270
+ cu_seqlens_k = cu_seqlens_k ,
271
+ max_seqlen_q = q_len ,
272
+ max_seqlen_k = kv_len ,
273
+ softmax_scale = self .scale ,
274
+ )
275
+ out = out .reshape (bsz , q_len , - 1 )
257
276
elif self .attn_backend == _Backend .XFORMERS :
258
277
from xformers import ops as xops
259
278
You can’t perform that action at this time.
0 commit comments