
Commit e7fbeda

rasmith and ProExpertProg authored and committed
[Quantization][FP8] Add support for FP8 models with input_scale for output projection and QK quantization (vllm-project#15734)
Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
1 parent 01882b7 commit e7fbeda

8 files changed (+105, −20 lines)

vllm/attention/backends/abstract.py

Lines changed: 1 addition & 0 deletions
@@ -237,6 +237,7 @@ class AttentionLayer(Protocol):
     _v_scale: torch.Tensor
     _k_scale_float: float
     _v_scale_float: float
+    _prob_scale: torch.Tensor
 
     def forward(
         self,

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 7 additions & 0 deletions
@@ -766,6 +766,12 @@ def forward(
                         query.dtype,
                         seq_lens,
                         make_attn_mask=causal_mask)  # type: ignore
+                use_fp8_scales = (layer._q_scale and layer._k_scale
+                                  and layer._v_scale and layer._prob_scale
+                                  and self.kv_cache_dtype == "fp8")
+                full_scales = (
+                    layer._q_scale, layer._k_scale, layer._v_scale,
+                    layer._prob_scale) if use_fp8_scales else None
                 self.triton_attn_func(
                     query,
                     key,
@@ -779,6 +785,7 @@ def forward(
                     self.scale,
                     attn_masks[0][None]
                     if attn_masks is not None else None,
+                    full_scales,
                 )
             elif self.use_naive_attn:
                 if self.num_kv_heads != self.num_heads:
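Note on the gate added above: each scale is a one-element tensor, and a one-element tensor used in `and` contributes its boolean value, so a zero scale or a non-fp8 kv-cache dtype yields full_scales = None. A minimal standalone sketch of that behavior (not vLLM code; the scale values below are made up):

import torch

q_scale = torch.tensor(1.0)
k_scale = torch.tensor(0.02)
v_scale = torch.tensor(0.03)
prob_scale = torch.tensor(0.0)   # hypothetical: prob scale not provided
kv_cache_dtype = "fp8"

use_fp8_scales = (q_scale and k_scale and v_scale and prob_scale
                  and kv_cache_dtype == "fp8")
full_scales = (q_scale, k_scale, v_scale,
               prob_scale) if use_fp8_scales else None
print(bool(use_fp8_scales), full_scales)  # False None -> unscaled path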

vllm/attention/layer.py

Lines changed: 1 addition & 0 deletions
@@ -90,6 +90,7 @@ def __init__(
         # FlashAttn doesn't support quantizing the kv-cache only
         # but requires q to be quantized as well.
         self._q_scale = torch.tensor(1.0, dtype=torch.float32)
+        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
 
         # We also keep the float32 versions of k/v_scale for attention
         # backends that don't support tensors (Flashinfer)

vllm/config.py

Lines changed: 11 additions & 0 deletions
@@ -3767,6 +3767,17 @@ def _get_quantization_config(
             return quant_config
         return None
 
+    @staticmethod
+    def get_quantization_config(
+            model_config: ModelConfig,
+            load_config: LoadConfig) -> Optional[QuantizationConfig]:
+        import copy
+
+        # For some reason, the _ version of this modifies the model_config
+        # object, so using deepcopy to avoid this problem.
+        return VllmConfig._get_quantization_config(copy.deepcopy(model_config),
+                                                   load_config)
+
     def with_hf_config(
         self,
         hf_config: PretrainedConfig,
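The new wrapper copies the config before delegating because, as the comment notes, the underscore-prefixed helper mutates its argument. A generic sketch of that deepcopy guard, with a stand-in dict and helper rather than the actual ModelConfig/LoadConfig types:

import copy

def _lookup_and_mutate(cfg: dict) -> str:
    # Stand-in for _get_quantization_config: it has a side effect on cfg.
    cfg["quantization"] = cfg.get("quantization") or "fp8"
    return cfg["quantization"]

def lookup(cfg: dict) -> str:
    # Work on a deep copy so the caller's config is left untouched.
    return _lookup_and_mutate(copy.deepcopy(cfg))

original = {"quantization": None}
print(lookup(original), original)  # fp8 {'quantization': None}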

vllm/engine/arg_utils.py

Lines changed: 17 additions & 0 deletions
@@ -1368,6 +1368,23 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=False)
             return False
 
+        if current_platform.is_rocm():
+            from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+            load_config = self.create_load_config()
+            quantization_config = VllmConfig.get_quantization_config(
+                model_config, load_config)
+            if isinstance(quantization_config, Fp8Config):
+                _raise_or_fallback(feature_name="fp8 for ROCm",
+                                   recommend_to_remove=False)
+                return False
+            from vllm.model_executor.layers.quantization.quark.quark import (
+                QuarkConfig)
+
+            if isinstance(quantization_config, QuarkConfig
+                          ) and quantization_config.has_fp8_layer_weights():
+                _raise_or_fallback(feature_name="Quark fp8 for ROCm",
+                                   recommend_to_remove=False)
+
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
             fp8_attention = self.kv_cache_dtype.startswith("fp8")

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 5 additions & 0 deletions
@@ -140,6 +140,11 @@ def get_cache_scale(self, name: str) -> Optional[str]:
             return name.replace(".k_proj.output_scale", ".attn.k_scale")
         if name.endswith(".output_scale") and ".v_proj" in name:
             return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")
+        # If no matches, return None
         return None
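For orientation, a standalone version of the remapping these new branches perform; the parameter names below are illustrative examples, not taken from a specific checkpoint:

from typing import Optional

def remap_scale_name(name: str) -> Optional[str]:
    # Mirrors the new q_proj / prob branches of get_cache_scale.
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None

print(remap_scale_name("model.layers.0.self_attn.q_proj.output_scale"))
# model.layers.0.self_attn.attn.q_scale
print(remap_scale_name("model.layers.0.self_attn.prob_output_scale"))
# model.layers.0.self_attn.attn.prob_scale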

vllm/model_executor/layers/quantization/kv_cache.py

Lines changed: 36 additions & 0 deletions
@@ -38,6 +38,9 @@ def create_weights(self, layer: torch.nn.Module):
                                            requires_grad=False)
         layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
                                            requires_grad=False)
+        # Initialize P = softmax(QK^T) scales
+        layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
+                                               requires_grad=False)
 
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(
@@ -97,5 +100,38 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 "may cause accuracy issues. Please make sure k/v_scale "
                 "scaling factors are available in the fp8 checkpoint.")
 
+        if layer.q_scale > 0.0:
+            q_scale = layer.q_scale
+            if current_platform.is_fp8_fnuz():
+                q_scale *= 2
+            layer.calculate_kv_scales = False
+        else:
+            q_scale = 1.0
+        if layer.prob_scale > 0.0:
+            prob_scale = layer.prob_scale
+            if current_platform.is_fp8_fnuz():
+                prob_scale *= 2
+        else:
+            prob_scale = 1.0
+
+        is_singleton_float = lambda x: isinstance(x, float) or isinstance(
+            x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
+        if not is_singleton_float(q_scale) or not is_singleton_float(
+                prob_scale):
+            raise ValueError("Only support per-tensor scaling factor "
+                             "for fp8-quantized Q/prob")
+
+        # These are used in the final Attention.forward()
+        layer._q_scale.copy_(q_scale)
+        layer._prob_scale.copy_(prob_scale)
+        if q_scale == 1.0 or prob_scale == 1.0:
+            logger.warning_once(
+                f"Using Q scale {q_scale} and prob scale {prob_scale} "
+                "with fp8 attention. This may cause accuracy issues. "
+                "Please make sure Q/prob scaling factors are "
+                "available in the fp8 checkpoint.")
+
         del layer.k_scale
         del layer.v_scale
+        del layer.q_scale
+        del layer.prob_scale
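The scale handling added here reads as: take the checkpoint scale when one was loaded (the -1.0 sentinel means "absent"), double it on fp8-fnuz hardware, and require per-tensor scales. A compressed standalone sketch of that logic, not the vLLM implementation; the doubling is hedged as the usual convention for e4m3fnuz, whose different exponent bias makes the same bit pattern encode roughly half the e4m3 value:

import torch

def is_singleton_float(x) -> bool:
    # A per-tensor scale: a Python float or a one-element floating-point tensor.
    return isinstance(x, float) or (isinstance(x, torch.Tensor)
                                    and x.numel() == 1
                                    and x.is_floating_point())

def resolve_scale(raw_scale: torch.Tensor, is_fp8_fnuz: bool) -> float:
    # -1.0 marks "not provided in the checkpoint"; fall back to 1.0.
    if raw_scale > 0.0:
        scale = raw_scale * 2 if is_fp8_fnuz else raw_scale
    else:
        scale = 1.0
    if not is_singleton_float(scale):
        raise ValueError("Only per-tensor scales are supported")
    return float(scale)

print(resolve_scale(torch.tensor(0.125), is_fp8_fnuz=True))   # 0.25
print(resolve_scale(torch.tensor(-1.0), is_fp8_fnuz=False))   # 1.0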

vllm/model_executor/layers/quantization/quark/quark.py

Lines changed: 27 additions & 20 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import fnmatch
-import re
 from typing import Any, Dict, List, Optional, cast
 
 import torch
@@ -125,6 +124,13 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
             for q_config in q_configs:
                 q_config["output_tensors"] = None
 
+            # In case q_proj output is also quantized, remove the configuration
+            # to keep qkv consistency.
+            q_proj_q_config = cast(Dict[str, Any],
+                                   layer_quant_config.get("*q_proj"))
+            if q_proj_q_config is not None:
+                q_proj_q_config["output_tensors"] = None
+
         return cls(quant_config=config,
                    kv_cache_group=kv_cache_group,
                    kv_cache_config=kv_cache_config,
@@ -289,29 +295,30 @@ def get_cache_scale(self, name: str) -> Optional[str]:
         :param name: param name
         :return: matching param name for KV cache scale in vLLM
         """
-        if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
-            return None
-
-        kv_proj_names = [
-            re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
-        ]
-        if name.endswith(".output_scale"):
-            if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
-                kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
-                return name.replace(kv_output_scale_name, ".attn.k_scale")
-
-            elif len(kv_proj_names) == 2:
-                for kv_proj_name in kv_proj_names:
-                    if kv_proj_name in name and kv_proj_name == "k_proj":
-                        return name.replace(".k_proj.output_scale",
-                                            ".attn.k_scale")
-                    elif kv_proj_name in name and kv_proj_name == "v_proj":
-                        return name.replace(".v_proj.output_scale",
-                                            ".attn.v_scale")
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")
 
         # If no matches, return None
         return None
 
+    def has_fp8_layer_weights(self):
+        layer_quant_config = self.quant_config.get("layer_quant_config")
+        to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
+        return any([
+            'fp8' in cast(
+                str,
+                to_dict(
+                    to_dict(to_dict(layer_quant_config).get(layer_name)).get(
+                        "weight")).get("dtype"))
+            for layer_name in ["*v_proj", "*k_proj", "*q_proj"]
+        ])
+
 
 class QuarkLinearMethod(LinearMethodBase):
 
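has_fp8_layer_weights() walks a nested dict of per-layer quantization settings. A hypothetical layer_quant_config of the shape it expects is sketched below; the dtype strings are illustrative only, not a documented Quark schema:

from typing import Any, Dict, cast

layer_quant_config: Dict[str, Any] = {
    "*q_proj": {"weight": {"dtype": "fp8_e4m3"}},
    "*k_proj": {"weight": {"dtype": "fp8_e4m3"}},
    "*v_proj": {"weight": {"dtype": "fp8_e4m3"}},
}

# Same traversal pattern as the method above, on the example config.
to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
has_fp8 = any(
    "fp8" in str(
        to_dict(to_dict(layer_quant_config.get(name)).get("weight")).get("dtype"))
    for name in ["*v_proj", "*k_proj", "*q_proj"])
print(has_fp8)  # True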
