
Commit 92159d7

ywang96 and Isotr0py committed
[Misc][Bugfix] FA3 support to ViT MHA layer (vllm-project#12435)
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
1 parent cd3e0e0 commit 92159d7

1 file changed: 22 additions, 3 deletions

vllm/attention/layer.py

22 additions, 3 deletions
@@ -251,9 +251,28 @@ def forward(
                 _Backend.FLASH_ATTN,
                 _Backend.FLASH_ATTN_VLLM_V1,
         }:
-            from vllm.vllm_flash_attn import flash_attn_func
-
-            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
+            from vllm.vllm_flash_attn import flash_attn_varlen_func
+
+            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
+                                        step=q_len,
+                                        dtype=torch.int32,
+                                        device=query.device)
+            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
+                                        step=kv_len,
+                                        dtype=torch.int32,
+                                        device=key.device)
+
+            out = flash_attn_varlen_func(
+                query.flatten(0, 1),
+                key.flatten(0, 1),
+                value.flatten(0, 1),
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=q_len,
+                max_seqlen_k=kv_len,
+                softmax_scale=self.scale,
+            )
+            out = out.reshape(bsz, q_len, -1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
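For context, below is a minimal sketch (plain PyTorch, CPU-only) of the packing step this commit introduces: building the cumulative-sequence-length tensors and flattening the batched ViT tensors into the packed layout that a varlen attention kernel consumes. The tensor names and example shapes (bsz, q_len, kv_len, num_heads, head_dim) are assumptions chosen to mirror the diff; the flash_attn_varlen_func call itself is omitted so the snippet runs without a GPU or flash-attn installed.

import torch

# Hypothetical example shapes; the real values come from the ViT inputs.
bsz, q_len, kv_len, num_heads, head_dim = 2, 16, 16, 8, 64

query = torch.randn(bsz, q_len, num_heads, head_dim)
key = torch.randn(bsz, kv_len, num_heads, head_dim)
value = torch.randn(bsz, kv_len, num_heads, head_dim)

# Cumulative sequence lengths: [0, q_len, 2*q_len, ...], one entry per
# batch boundary. Because every ViT sequence in the batch has the same
# length, a simple arange with step=q_len is sufficient.
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32)

# flatten(0, 1) merges the batch and sequence dimensions, producing the
# packed (total_tokens, num_heads, head_dim) layout used by varlen kernels.
q_packed = query.flatten(0, 1)   # (bsz * q_len, num_heads, head_dim)
k_packed = key.flatten(0, 1)     # (bsz * kv_len, num_heads, head_dim)
v_packed = value.flatten(0, 1)   # (bsz * kv_len, num_heads, head_dim)

# The last cu_seqlens entry equals the total number of packed tokens.
assert q_packed.shape[0] == int(cu_seqlens_q[-1])
assert k_packed.shape[0] == int(cu_seqlens_k[-1])

After the kernel returns, the output is reshaped back to (bsz, q_len, -1), as the last added line of the diff shows.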