Commit de0526f

[Misc][Quark] Upstream Quark format to VLLM (#10765)

Authored by kewang-xlnx, kewang2, and mgoin

Signed-off-by: kewang-xlnx <[email protected]>
Signed-off-by: kewang2 <[email protected]>
Co-authored-by: kewang2 <[email protected]>
Co-authored-by: Michael Goin <[email protected]>

1 parent (5ecf3e0) · commit de0526f

32 files changed: +1264 -70 lines

tests/quantization/test_quark.py
Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+"""Test model set-up and weight loading for quark-quantized models.
+
+Run `pytest tests/quantization/test_quark.py`.
+"""
+
+import torch
+
+from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
+    QuarkLinearMethod, QuarkW8A8Fp8)
+
+
+def test_quark_fp8(vllm_runner):
+    model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
+    with vllm_runner(model_path) as llm:
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        layer = model.model.layers[0]
+
+        qkv_proj = layer.self_attn.qkv_proj
+
+        assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
+        assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)
+
+        if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
+            assert len(qkv_proj.input_scale.shape) == 0
+            assert qkv_proj.weight.dtype is torch.float8_e4m3fn
+            #assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
+            assert len(qkv_proj.weight_scale.shape) == 0
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        assert output
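
For context, a minimal sketch of exercising the same checkpoint through vLLM's offline LLM API. The model name comes from the test above; passing quantization="quark" explicitly is an assumption, since the method may also be picked up from the checkpoint's quantization config.

# Sketch only: offline greedy generation with the Quark FP8 test checkpoint.
# Assumes a GPU environment where vLLM can load the model; the explicit
# quantization="quark" argument is an assumption, not required by the test.
from vllm import LLM, SamplingParams

llm = LLM(model="amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test",
          quantization="quark")
params = SamplingParams(temperature=0.0, max_tokens=20)  # greedy, as in the test
outputs = llm.generate(["Hello my name is"], params)
print(outputs[0].outputs[0].text)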

vllm/config.py
Lines changed: 1 addition & 1 deletion

@@ -553,7 +553,7 @@ def _verify_quantization(self) -> None:
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed_tensors",
-            "compressed-tensors", "experts_int8"
+            "compressed-tensors", "experts_int8", "quark"
         ]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()

vllm/model_executor/layers/linear.py
Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@
     "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
     "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
     "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
-    "HQQMarlinMethod"
+    "HQQMarlinMethod", "QuarkLinearMethod"
 ]
vllm/model_executor/layers/quantization/__init__.py
Lines changed: 4 additions & 0 deletions

@@ -26,6 +26,7 @@
     "experts_int8",
     "neuron_quant",
     "ipex",
+    "quark"
 ]

@@ -34,6 +35,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         raise ValueError(f"Invalid quantization method: {quantization}")

     # lazy import to avoid triggering `torch.compile` too early
+    from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
+
     from .aqlm import AQLMConfig
     from .awq import AWQConfig
     from .awq_marlin import AWQMarlinConfig

@@ -79,6 +82,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         "experts_int8": ExpertsInt8Config,
         "neuron_quant": NeuronQuantConfig,
         "ipex": IPEXConfig,
+        "quark": QuarkConfig
     }

     return method_to_config[quantization]
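
A small sanity-check sketch of what this registration enables, assuming the quark package added by this commit is importable:

# Sketch: resolving the newly registered "quark" method should return the
# lazily imported QuarkConfig class added by this commit.
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                     get_quantization_config)

assert "quark" in QUANTIZATION_METHODS
quark_config_cls = get_quantization_config("quark")
print(quark_config_cls.__name__)  # expected: QuarkConfig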

vllm/model_executor/layers/quantization/base_config.py
Lines changed: 3 additions & 0 deletions

@@ -133,3 +133,6 @@ def get_quant_method(self, layer: torch.nn.Module,
         method.
         """
         raise NotImplementedError
+
+    def get_cache_scale(self, name: str) -> Optional[str]:
+        return None
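
This default gives every QuantizationConfig a no-op hook that weight-loading code can query. A hedged sketch of how a loader might use it; remap_kv_scale_name is a hypothetical helper written for illustration, not vLLM's actual loader code:

# Hypothetical helper for illustration only, not part of this commit.
from typing import Optional


def remap_kv_scale_name(name: str, quant_config) -> str:
    """Return the vLLM-side parameter name for a checkpoint name if the
    quantization config recognizes it as a KV cache scale; otherwise return
    the name unchanged."""
    remapped: Optional[str] = quant_config.get_cache_scale(name)
    return name if remapped is None else remapped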

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
Lines changed: 16 additions & 0 deletions

@@ -412,6 +412,22 @@ def get_scheme(
         self._check_scheme_supported(scheme.get_min_capability())
         return scheme

+    def get_cache_scale(self, name: str) -> Optional[str]:
+        """
+        Check whether the param name matches the format for k/v cache scales
+        in compressed-tensors. If this is the case, return its equivalent
+        param name expected by vLLM
+
+        :param name: param name
+        :return: matching param name for KV cache scale in vLLM
+        """
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        # If no matches, return None
+        return None
+
     @staticmethod
     def supports_cutlass_24(
         weight_quant: Optional[QuantizationArgs],
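
The remapping itself is plain string surgery. A standalone example of the expected behavior, reimplemented here only so it runs without constructing a CompressedTensorsConfig:

# Standalone restatement of the mapping above, for illustration.
from typing import Optional


def cache_scale_name(name: str) -> Optional[str]:
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    return None


assert (cache_scale_name("model.layers.0.self_attn.k_proj.output_scale")
        == "model.layers.0.self_attn.attn.k_scale")
assert cache_scale_name("model.layers.0.self_attn.q_proj.weight") is None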

vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
Lines changed: 4 additions & 0 deletions

@@ -136,6 +136,10 @@ def triton_scaled_mm(input: torch.Tensor,
     assert N > 0 and K > 0 and M > 0
     assert weight.shape[0] == K
     assert input.dtype == weight.dtype
+
+    scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
+    scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
+
     assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
     assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
         [M, 1])
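
A short illustration of what the added reshape does: 0-D (per-tensor) and 1-D (per-row) scales are normalized to the 2-D shapes the following asserts expect, while 2-D scales pass through unchanged. The example values are made up:

# Illustration only: the same reshape applied to example scale tensors.
import torch

M = 4
scale_a = torch.tensor(0.5)   # 0-D per-tensor scale
scale_b = torch.rand(M)       # 1-D per-row scale

scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b

assert scale_a.shape == torch.Size([1, 1])
assert scale_b.shape == torch.Size([M, 1])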

vllm/model_executor/layers/quantization/compressed_tensors/utils.py
Lines changed: 0 additions & 17 deletions

@@ -133,23 +133,6 @@ def _find_first_match(value: str,
     return None


-def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
-    """
-    Check whether the param name matches the format for k/v cache scales
-    in compressed-tensors. If this is the case, return its equivalent
-    param name expected by vLLM
-
-    :param name: param name
-    :return: matching param name for KV cache scale in vLLM
-    """
-    if name.endswith(".output_scale") and ".k_proj" in name:
-        return name.replace(".k_proj.output_scale", ".attn.k_scale")
-    if name.endswith(".output_scale") and ".v_proj" in name:
-        return name.replace(".v_proj.output_scale", ".attn.v_scale")
-    # If no matches, return None
-    return None
-
-
 def _is_equal_or_regex_match(value: str,
                              target: str,
                              check_contains: bool = False) -> bool:

vllm/model_executor/layers/quantization/quark/__init__.py
Whitespace-only changes.
