Skip to content

Commit 32eb0da

Browse files
authored
[Misc] Support register quantization method out-of-tree (#11969)
1 parent 6d0e3d3 commit 32eb0da

File tree

2 files changed

+158
-0
lines changed

2 files changed

+158
-0
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Tests register custom quantization config.
2+
3+
See https://github.com/vllm-project/vllm/issues/11926 for more details.
4+
5+
Run `pytest tests/quantization/test_register_quantization_config.py`.
6+
"""
7+
from typing import Any, Dict, List, Optional
8+
9+
import pytest
10+
import torch
11+
import torch.nn.functional as F
12+
13+
from vllm.model_executor.layers.linear import LinearBase # noqa: E501
14+
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
15+
from vllm.model_executor.layers.quantization import (
16+
get_quantization_config, register_quantization_config)
17+
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
18+
QuantizationConfig)
19+
20+
21+
class FakeQuantLinearMethod(UnquantizedLinearMethod):
    """Linear method that fake-quantizes the activations before the matmul.

    The input is scaled, rounded and clamped to a signed ``num_bits`` integer
    range, immediately de-quantized again, and then fed to a regular
    ``F.linear`` with the layer's weight. The scale is recomputed from the
    input on every call (dynamic quantization).
    """

    def __init__(self, num_bits: int = 8) -> None:
        """Initialize the quantization method.

        Args:
            num_bits: Bit width of the simulated integer representation.
        """
        super().__init__()
        self.num_bits = num_bits

    def apply(self,
              layer: "torch.nn.Module",
              x: "torch.Tensor",
              bias: Optional["torch.Tensor"] = None) -> "torch.Tensor":
        """Perform fake quantization before the linear layer.

        Args:
            layer: Module whose ``weight`` attribute is used for the matmul.
            x: Input activations.
            bias: Optional bias forwarded to ``F.linear``.

        Returns:
            ``F.linear`` applied to the quantize/de-quantize round trip of
            ``x``.
        """
        # Calculate the scales dynamically from the observed value range.
        # NOTE(review): reducing over dims (0, -1) collapses every element of
        # a 2-D input into a single scale, which looks per-tensor rather than
        # per-token -- confirm the intended granularity.
        max_val = torch.amax(x, dim=(0, -1), keepdim=True)
        min_val = torch.amin(x, dim=(0, -1), keepdim=True)
        scales = (max_val - min_val) / (2**self.num_bits - 1)

        # A constant input (e.g. all zeros) gives max == min and therefore a
        # zero scale; dividing by it would produce NaN/inf activations. Floor
        # the scale at the dtype's epsilon so the division stays finite.
        scales = scales.clamp(min=torch.finfo(scales.dtype).eps)

        # Fake quantize the input: round to the integer grid, clamp to the
        # signed num_bits range, then map back to the original scale.
        q_min = -2**(self.num_bits - 1)
        q_max = 2**(self.num_bits - 1) - 1
        quant_x = torch.clamp(torch.round(x / scales), q_min, q_max)
        dequant_x = quant_x * scales

        return F.linear(dequant_x, layer.weight, bias)
46+
47+
48+
@register_quantization_config("custom_quant")
class CustomQuantConfig(QuantizationConfig):
    """Custom quantization config for dynamic fake quantization.

    Registered under the method name ``custom_quant``; linear layers are
    handled by ``FakeQuantLinearMethod``.
    """

    def __init__(self, num_bits: int = 8) -> None:
        """Initialize the quantization config.

        Args:
            num_bits: Bit width passed on to ``FakeQuantLinearMethod``.
        """
        # Let the base class set up any state it owns before adding ours.
        super().__init__()
        self.num_bits = num_bits

    def get_name(self) -> str:
        """Name of the quantization method."""
        return "custom_quant"

    def get_supported_act_dtypes(self) -> List["torch.dtype"]:
        """List of supported activation dtypes."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method.

        Returns -1; presumably this means "no minimum requirement" --
        confirm against the ``QuantizationConfig`` contract.
        """
        return -1

    @staticmethod
    def get_config_filenames() -> List[str]:
        """List of filenames to search for in the model directory."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CustomQuantConfig":
        """Create a config instance from the model's quantization config.

        Args:
            config: Quantization config mapping; ``num_bits`` is read from
                it and defaults to 8 when absent.
        """
        # Build through ``cls`` so subclasses get an instance of their own
        # type rather than the hard-coded base class.
        return cls(num_bits=config.get("num_bits", 8))

    def get_quant_method(self, layer: "torch.nn.Module",
                         prefix: str) -> Optional["FakeQuantLinearMethod"]:
        """Get the quantize method to use for the quantized layer.

        Returns a ``FakeQuantLinearMethod`` for ``LinearBase`` layers and
        ``None`` for every other layer type.
        """
        if isinstance(layer, LinearBase):
            return FakeQuantLinearMethod(num_bits=self.num_bits)
        return None
85+
86+
87+
def test_register_quantization_config():
    """Check registration lookup and rejection of duplicate method names."""
    # `custom_quant` was registered by the decorator on CustomQuantConfig,
    # so the lookup must hand back that very class.
    registered_cls = get_quantization_config("custom_quant")
    assert registered_cls == CustomQuantConfig

    # Registering the same method name a second time must be rejected.
    with pytest.raises(ValueError):
        decorator = register_quantization_config("custom_quant")
        decorator(CustomQuantConfig)
97+
98+
99+
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B-Instruct"])
def test_custom_quant(vllm_runner, model):
    """Run inference end to end with the custom quantization method."""
    runner_kwargs = dict(model_name=model,
                         quantization="custom_quant",
                         enforce_eager=True)
    with vllm_runner(**runner_kwargs) as llm:
        # Walk down from the engine to the first decoder layer's QKV
        # projection.
        model_runner = llm.model.llm_engine.model_executor.driver_worker.model_runner  # noqa: E501
        first_layer = model_runner.model.model.layers[0]
        quant_method = first_layer.self_attn.qkv_proj.quant_method

        # The layer must have picked up the custom linear method.
        assert isinstance(quant_method, FakeQuantLinearMethod)

        generated = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert generated

vllm/model_executor/layers/quantization/__init__.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,45 @@
2929
"quark"
3030
]
3131

32+
# Registry of customized quantization methods: maps each registered method name to its config class.
33+
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}
34+
35+
36+
def register_quantization_config(quantization: str):
    """Register a customized vllm quantization config.

    When a quantization method is not supported by vllm, you can register a customized
    quantization config to support it.

    Args:
        quantization (str): The quantization method name.

    Returns:
        A class decorator. Applied to a ``QuantizationConfig`` subclass, it
        stores the class under ``quantization``, appends the name to
        ``QUANTIZATION_METHODS`` and returns the class unchanged.

    Raises:
        ValueError: Raised by the returned decorator when ``quantization`` is
            already present in ``QUANTIZATION_METHODS``, or when the decorated
            object is not a subclass of ``QuantizationConfig``.

    Examples:
        >>> from vllm.model_executor.layers.quantization import register_quantization_config
        >>> from vllm.model_executor.layers.quantization import get_quantization_config
        >>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
        >>>
        >>> @register_quantization_config("my_quant")
        ... class MyQuantConfig(QuantizationConfig):
        ...     pass
        >>>
        >>> get_quantization_config("my_quant")
        <class 'MyQuantConfig'>
    """  # noqa: E501

    def _wrapper(quant_config_cls):
        if quantization in QUANTIZATION_METHODS:
            raise ValueError(
                f"The quantization method `{quantization}` is already exists.")
        # issubclass() raises TypeError for non-class arguments (instances,
        # functions, ...); test for a class first so every invalid input
        # produces the documented ValueError.
        is_config_cls = (isinstance(quant_config_cls, type)
                         and issubclass(quant_config_cls, QuantizationConfig))
        if not is_config_cls:
            raise ValueError("The quantization config must be a subclass of "
                             "`QuantizationConfig`.")
        _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
        QUANTIZATION_METHODS.append(quantization)
        return quant_config_cls

    return _wrapper
70+
3271

3372
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
3473
if quantization not in QUANTIZATION_METHODS:
@@ -84,6 +123,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
84123
"ipex": IPEXConfig,
85124
"quark": QuarkConfig
86125
}
126+
# Update the `method_to_config` with customized quantization methods.
127+
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
87128

88129
return method_to_config[quantization]
89130

0 commit comments

Comments
 (0)