From 0d5228d3c9ebbd0fc1435bd3b02b14d898f8f409 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Sat, 25 Jan 2025 18:25:33 -0800
Subject: [PATCH 1/2] update flash attn API

Signed-off-by: Roger Wang
---
 vllm/attention/layer.py | 24 +++++++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index a90bb4fbf5a..fc860af70f1 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -251,9 +251,27 @@ def forward(
                 _Backend.FLASH_ATTN,
                 _Backend.FLASH_ATTN_VLLM_V1,
         }:
-            from vllm.vllm_flash_attn import flash_attn_func
-
-            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
+            from vllm.vllm_flash_attn import flash_attn_varlen_func
+
+            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
+                                        step=q_len,
+                                        dtype=torch.int32,
+                                        device=query.device)
+            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
+                                        step=kv_len,
+                                        dtype=torch.int32,
+                                        device=key.device)
+
+            out = flash_attn_varlen_func(
+                query,
+                key,
+                value,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=q_len,
+                max_seqlen_k=kv_len,
+                softmax_scale=self.scale,
+            )
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 

From 159f0f2b25d42eaac2447eb86610bbb8657ff38a Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sun, 26 Jan 2025 10:59:38 +0800
Subject: [PATCH 2/2] flatten qkv

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/attention/layer.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index fc860af70f1..db682b4ac63 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -263,15 +263,16 @@ def forward(
                                         device=key.device)
 
             out = flash_attn_varlen_func(
-                query,
-                key,
-                value,
+                query.flatten(0, 1),
+                key.flatten(0, 1),
+                value.flatten(0, 1),
                 cu_seqlens_q=cu_seqlens_q,
                 cu_seqlens_k=cu_seqlens_k,
                 max_seqlen_q=q_len,
                 max_seqlen_k=kv_len,
                 softmax_scale=self.scale,
             )
+            out = out.reshape(bsz, q_len, -1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
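
Note (reviewer sketch, not part of either patch): taken together, the two commits switch this path from flash_attn_func to flash_attn_varlen_func, which expects query/key/value packed as (total_tokens, num_heads, head_size) along with cumulative sequence-length offsets. The snippet below is a minimal sketch of that mapping for the equal-length-batch case, reusing the patch's shape names (bsz, q_len, kv_len) but substituting a plain PyTorch per-sequence attention loop for the real vllm_flash_attn kernel so it runs without the CUDA extension; num_heads and head_size are made-up example sizes.

# Sketch only: shows how cu_seqlens built with torch.arange describe a batch of
# equal-length sequences once q/k/v are flattened from (bsz, seq_len, heads, dim)
# to (bsz * seq_len, heads, dim). A naive PyTorch attention stands in for
# flash_attn_varlen_func.
import torch

bsz, q_len, kv_len, num_heads, head_size = 2, 4, 4, 8, 64
scale = head_size ** -0.5

query = torch.randn(bsz, q_len, num_heads, head_size)
key = torch.randn(bsz, kv_len, num_heads, head_size)
value = torch.randn(bsz, kv_len, num_heads, head_size)

# Cumulative offsets: [0, q_len, 2*q_len, ..., bsz*q_len], as in the patch.
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32)

# Packed (varlen) layout: sequences concatenated along the token dimension,
# matching the query.flatten(0, 1) calls added in the second patch.
q_packed = query.flatten(0, 1)   # (bsz * q_len, num_heads, head_size)
k_packed = key.flatten(0, 1)     # (bsz * kv_len, num_heads, head_size)
v_packed = value.flatten(0, 1)

# Reference attention computed per sequence using the cu_seqlens boundaries.
outs = []
for i in range(bsz):
    q_i = q_packed[cu_seqlens_q[i]:cu_seqlens_q[i + 1]].transpose(0, 1)  # (heads, q_len, dim)
    k_i = k_packed[cu_seqlens_k[i]:cu_seqlens_k[i + 1]].transpose(0, 1)  # (heads, kv_len, dim)
    v_i = v_packed[cu_seqlens_k[i]:cu_seqlens_k[i + 1]].transpose(0, 1)
    attn = torch.softmax(q_i @ k_i.transpose(-1, -2) * scale, dim=-1)
    outs.append((attn @ v_i).transpose(0, 1))                            # (q_len, heads, dim)

# Same final reshape as the second patch: (bsz, q_len, num_heads * head_size).
out = torch.cat(outs, dim=0).reshape(bsz, q_len, -1)
print(out.shape)  # torch.Size([2, 4, 512])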