11 files changed, +59 -20 lines changed
@@ -294,3 +294,7 @@ def forward(
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
+
+
+def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
+    return kv_cache_dtype != "auto"
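The new helper simply treats any KV-cache dtype other than "auto" as quantized. A minimal sketch of its behaviour, which the backends below rely on; the dtype strings are illustrative examples, not an exhaustive list:

# Behaviour of the helper added to vllm.attention.backends.abstract.
# "auto" means "keep the model dtype for the KV cache"; anything else
# (e.g. "fp8" or "fp8_e4m3") selects a quantized cache.
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype != "auto"

assert not is_quantized_kv_cache("auto")      # unquantized KV cache
assert is_quantized_kv_cache("fp8")           # FP8-quantized KV cache
assert is_quantized_kv_cache("fp8_e4m3")      # FP8-quantized KV cache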
@@ -8,11 +8,15 @@
 import torch
 
 from vllm import _custom_ops as ops
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionType,
+                                              is_quantized_kv_cache)
+# yapf: enable
 from vllm.attention.backends.utils import (
     PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
     compute_slot_mapping_start_idx, get_flash_attn_version,
@@ -626,6 +630,9 @@ def __init__(
         self.sliding_window = ((sliding_window - 1,
                                 0) if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashAttention with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -6,7 +6,8 @@
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                 MLACommonImpl,
                                                 MLACommonMetadata,
@@ -207,6 +208,10 @@ def __init__(
                 "are not implemented for "
                 "FlashMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashMLA with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -215,8 +220,6 @@ def _forward_decode(
         attn_metadata: FlashMLAMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 FlashMLA not yet supported")
 
         decode_meta = attn_metadata.decode_metadata
         assert decode_meta is not None
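Worth noting about this hunk (the TritonMLA and FlashMLA V1 hunks below follow the same pattern): the FP8 check moves out of _forward_decode and into __init__, so an unsupported kv_cache_dtype is rejected when the backend is constructed rather than on the first decode step. A rough pytest-style sketch of that behaviour, using a hypothetical FakeMLAImpl rather than the real constructor signature:

# Hypothetical sketch: FakeMLAImpl only models the kv_cache_dtype guard;
# the real Impl classes take many more constructor arguments.
import pytest


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype != "auto"


class FakeMLAImpl:

    def __init__(self, kv_cache_dtype: str) -> None:
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError(
                "FlashMLA with FP8 KV cache not yet supported")
        self.kv_cache_dtype = kv_cache_dtype


def test_fp8_kv_cache_rejected_at_construction():
    # Before this change, the same configuration only failed inside
    # _forward_decode, i.e. after the engine had already been built.
    with pytest.raises(NotImplementedError):
        FakeMLAImpl(kv_cache_dtype="fp8_e4m3")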
@@ -15,7 +15,8 @@
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
                                                HPUPagedAttentionMetadata)
@@ -158,6 +159,10 @@ def __init__(
                 "are not implemented for "
                 "HPUAttentionImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "HPUAttention with FP8 KV cache not yet supported")
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -9,7 +9,8 @@
 from vllm._ipex_ops import ipex_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
@@ -145,7 +146,7 @@ def __init__(
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError(
                 "IPEX backend does not support FP8 KV cache. "
                 "Please use xFormers backend instead.")
@@ -8,7 +8,8 @@
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 
 
@@ -119,7 +120,7 @@ def __init__(
             raise NotImplementedError("Alibi slopes is not supported.")
         if sliding_window is not None:
             raise NotImplementedError("Sliding window is not supported.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
@@ -7,11 +7,15 @@
 import torch
 from torch.nn.functional import scaled_dot_product_attention
 
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionType,
+                                              is_quantized_kv_cache)
+# yapf: enable
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.ipex_attn import PagedAttention
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
@@ -427,7 +431,7 @@ def __init__(
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError(
                 "Torch SDPA backend does not support FP8 KV cache. "
                 "Please use xFormers backend instead.")
@@ -4,7 +4,8 @@
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                 MLACommonImpl,
                                                 MLACommonMetadata)
@@ -58,6 +59,10 @@ def __init__(
                 "are not implemented for "
                 "TritonMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "TritonMLA with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -66,8 +71,6 @@ def _forward_decode(
         attn_metadata: MLACommonMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 Triton MLA not yet supported")
 
         decode_meta = attn_metadata.decode_metadata
         assert decode_meta is not None
@@ -7,7 +7,8 @@
 import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
 from vllm.logger import init_logger
@@ -180,6 +181,9 @@ def __init__(
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashAttention V1 with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -5,7 +5,8 @@
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
@@ -115,6 +116,10 @@ def __init__(
                 "are not implemented for "
                 "FlashMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -125,9 +130,6 @@ def _forward_decode(
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 FlashMLA not yet supported")
-
         q = torch.cat([q_nope, q_pe], dim=-1)\
             .unsqueeze(1)  # Add seqlen dim of 1 (decode)
 
@@ -4,7 +4,8 @@
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -61,6 +62,10 @@ def __init__(
                 "are not implemented for "
                 "TritonMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "TritonMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
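Taken together, the user-visible effect is that requesting a quantized KV cache on one of these backends now fails while the engine is being built rather than partway through generation. A rough usage sketch follows; the model name is only a placeholder, and whether the error is actually raised depends on which attention backend vLLM selects for the model and hardware:

# Hypothetical usage sketch, not part of this PR.
from vllm import LLM

try:
    llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8")
except NotImplementedError as err:
    # Backends patched in this PR now reject a quantized KV cache in
    # their Impl constructor, e.g.
    # "FlashAttention with FP8 KV cache not yet supported".
    print(f"FP8 KV cache rejected at engine construction: {err}")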