
Commit f8a96ba

mgoin authored and DamonFool committed
[V1][Bugfix] Standardize quantized kv cache rejection for attention backends (vllm-project#14221)
Signed-off-by: mgoin <[email protected]>
1 parent 4e9f173 commit f8a96ba

Showing 11 changed files with 59 additions and 20 deletions.


vllm/attention/backends/abstract.py

Lines changed: 4 additions & 0 deletions
@@ -294,3 +294,7 @@ def forward(
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
+
+
+def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
+    return kv_cache_dtype != "auto"
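For orientation: the new helper treats any kv_cache_dtype other than "auto" (e.g. "fp8") as quantized, so backends that cannot read a quantized KV cache can fail fast in their constructor instead of deep inside the decode path. A minimal sketch of the pattern the diffs below repeat, using a hypothetical MyAttentionImpl class that is not part of this commit:

from vllm.attention.backends.abstract import is_quantized_kv_cache

class MyAttentionImpl:  # hypothetical backend impl, for illustration only
    def __init__(self, kv_cache_dtype: str = "auto") -> None:
        self.kv_cache_dtype = kv_cache_dtype
        # Reject quantized (e.g. fp8) KV caches at construction time rather
        # than when decoding starts, mirroring the checks added below.
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "MyAttention with FP8 KV cache not yet supported")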

vllm/attention/backends/flash_attn.py

Lines changed: 8 additions & 1 deletion
@@ -8,11 +8,15 @@
 import torch
 
 from vllm import _custom_ops as ops
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionType,
+                                              is_quantized_kv_cache)
+# yapf: enable
 from vllm.attention.backends.utils import (
     PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
     compute_slot_mapping_start_idx, get_flash_attn_version,
@@ -626,6 +630,9 @@ def __init__(
         self.sliding_window = ((sliding_window - 1,
                                 0) if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashAttention with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0

vllm/attention/backends/flashmla.py

Lines changed: 6 additions & 3 deletions
@@ -6,7 +6,8 @@
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                 MLACommonImpl,
                                                 MLACommonMetadata,
@@ -207,6 +208,10 @@ def __init__(
                                       "are not implemented for "
                                       "FlashMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashMLA with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -215,8 +220,6 @@ def _forward_decode(
         attn_metadata: FlashMLAMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 FlashMLA not yet supported")
 
         decode_meta = attn_metadata.decode_metadata
         assert decode_meta is not None

vllm/attention/backends/hpu_attn.py

Lines changed: 6 additions & 1 deletion
@@ -15,7 +15,8 @@
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
                                                HPUPagedAttentionMetadata)
@@ -158,6 +159,10 @@ def __init__(
                                       "are not implemented for "
                                       "HPUAttentionImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "HPUAttention with FP8 KV cache not yet supported")
+
     def forward(
         self,
         layer: AttentionLayer,

vllm/attention/backends/ipex_attn.py

Lines changed: 3 additions & 2 deletions
@@ -9,7 +9,8 @@
 from vllm._ipex_ops import ipex_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
@@ -145,7 +146,7 @@ def __init__(
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError(
                 "IPEX backend does not support FP8 KV cache. "
                 "Please use xFormers backend instead.")

vllm/attention/backends/pallas.py

Lines changed: 3 additions & 2 deletions
@@ -8,7 +8,8 @@
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 
 
@@ -119,7 +120,7 @@ def __init__(
             raise NotImplementedError("Alibi slopes is not supported.")
         if sliding_window is not None:
             raise NotImplementedError("Sliding window is not supported.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")

vllm/attention/backends/torch_sdpa.py

Lines changed: 6 additions & 2 deletions
@@ -7,11 +7,15 @@
 import torch
 from torch.nn.functional import scaled_dot_product_attention
 
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionType,
+                                              is_quantized_kv_cache)
+# yapf: enable
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.ipex_attn import PagedAttention
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
@@ -427,7 +431,7 @@ def __init__(
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError(
                 "Torch SDPA backend does not support FP8 KV cache. "
                 "Please use xFormers backend instead.")

vllm/attention/backends/triton_mla.py

Lines changed: 6 additions & 3 deletions
@@ -4,7 +4,8 @@
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                 MLACommonImpl,
                                                 MLACommonMetadata)
@@ -58,6 +59,10 @@ def __init__(
                                       "are not implemented for "
                                       "TritonMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "TritonMLA with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -66,8 +71,6 @@ def _forward_decode(
         attn_metadata: MLACommonMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 Triton MLA not yet supported")
 
         decode_meta = attn_metadata.decode_metadata
         assert decode_meta is not None

vllm/v1/attention/backends/flash_attn.py

Lines changed: 5 additions & 1 deletion
@@ -7,7 +7,8 @@
 import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
 from vllm.logger import init_logger
@@ -180,6 +181,9 @@ def __init__(
         else:
            self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashAttention V1 with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0

vllm/v1/attention/backends/mla/flashmla.py

Lines changed: 6 additions & 4 deletions
@@ -5,7 +5,8 @@
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
@@ -115,6 +116,10 @@ def __init__(
                                       "are not implemented for "
                                       "FlashMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -125,9 +130,6 @@ def _forward_decode(
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 FlashMLA not yet supported")
-
         q = torch.cat([q_nope, q_pe], dim=-1)\
             .unsqueeze(1)  # Add seqlen dim of 1 (decode)
 
vllm/v1/attention/backends/mla/triton_mla.py

Lines changed: 6 additions & 1 deletion
@@ -4,7 +4,8 @@
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -61,6 +62,10 @@ def __init__(
                                       "are not implemented for "
                                       "TritonMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "TritonMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
