
Commit 845b0a2

Efficient Inference Kernel for SpQR (#34976)

Authored by: elvircrn, stevhliu, jiqing-feng, LRL-ModelCloud, ZX-ModelCloud
* Resolve vptq conflict
* Rename spqr package to spqr_quant
* Get rid of aqlm mention
* Start working on tests
* Resolve ruff code checks
* Ruff format
* Isort
* Test updates
* Add gpu tag
* Rename to modules_to_not_convert
* Config update
* Docs and config update
* Docs and config update
* Update to update_torch_dtype
* spqr config parameter validation
* Ruff update
* Apply ruff fixes
* Test fixes
* Ruff update
* Mark tests as @slow again; Ruff; Docstring update
* Ruff
* Remove absolute path
* Resolve typo
* Remove redundandt log
* Check accelerate/spqr availability
* Ruff fix
* Check if the config contains proper shapes
* Ruff test
* Documentation update
* overview update
* Ruff checks
* Ruff code quality
* Make style
* Update docs/source/en/quantization/spqr.md (Co-authored-by: Steven Liu <[email protected]>)
* Update spqr.md
* Enable gptqmodel (#35012)
* gptqmodel (Signed-off-by: jiqing-feng <[email protected]>)
* fix format (Signed-off-by: jiqing-feng <[email protected]>)
* update readme (Signed-off-by: jiqing-feng <[email protected]>)
* gptqmodel need use checkpoint_format (#1)
* gptqmodel need use checkpoint_format
* fix quantize
* Update quantization_config.py
* Update quantization_config.py
* Update quantization_config.py
---------
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: Qubitium-ModelCloud <[email protected]>
* Revert quantizer_gptq.py (#2)
* revert quantizer_gptq.py change
* pass **kwargs
* limit gptqmodel and optimum version (Signed-off-by: jiqing-feng <[email protected]>)
* fix format (Signed-off-by: jiqing-feng <[email protected]>)
* fix warning (Signed-off-by: jiqing-feng <[email protected]>)
* fix version check (Signed-off-by: jiqing-feng <[email protected]>)
* revert unrelated changes (Signed-off-by: jiqing-feng <[email protected]>)
* enable gptqmodel tests (Signed-off-by: jiqing-feng <[email protected]>)
* fix requires gptq (Signed-off-by: jiqing-feng <[email protected]>)
* Fix Transformer compat (#3)
* revert quantizer_gptq.py change
* pass **kwargs
* add meta info
* cleanup
* cleanup
* Update quantization_config.py
* hf_select_quant_linear pass checkpoint_format and meta
* fix GPTQTestCUDA
* Update test_gptq.py
* gptqmodel.hf_select_quant_linear() now does not select ExllamaV2
* cleanup
* add backend
* cleanup
* cleanup
* no need check exllama version
* Update quantization_config.py
* lower checkpoint_format and backend
* check none
* cleanup
* Update quantization_config.py
* fix self.use_exllama == False
* spell
* fix unittest
* fix unittest
---------
Co-authored-by: LRL <[email protected]>
Co-authored-by: Qubitium-ModelCloud <[email protected]>
* fix format (Signed-off-by: jiqing-feng <[email protected]>)
* fix format again (Signed-off-by: jiqing-feng <[email protected]>)
* update gptqmodel version (#6)
* update gptqmodel version
* update gptqmodel version
* fix unit test (#5)
* update gptqmodel version
* update gptqmodel version
* "not self.use_exllama" is not equivalent to "self.use_exllama==False"
* fix unittest
* update gptqmodel version
* backend is loading_attibutes (#7)
* fix format and tests (Signed-off-by: jiqing-feng <[email protected]>)
* fix memory check (Signed-off-by: jiqing-feng <[email protected]>)
* fix device mismatch (Signed-off-by: jiqing-feng <[email protected]>)
* fix result check (Signed-off-by: jiqing-feng <[email protected]>)
* Update src/transformers/quantizers/quantizer_gptq.py (Co-authored-by: Marc Sun <[email protected]>)
* Update src/transformers/quantizers/quantizer_gptq.py (Co-authored-by: Marc Sun <[email protected]>)
* Update src/transformers/quantizers/quantizer_gptq.py (Co-authored-by: Marc Sun <[email protected]>)
* update tests (Signed-off-by: jiqing-feng <[email protected]>)
* review: update docs (#10)
* review: update docs (#12)
* review: update docs
* fix typo
* update tests for gptqmodel (Signed-off-by: jiqing-feng <[email protected]>)
* update document (#9)
* update overview.md
* cleanup
* Update overview.md
* Update overview.md
* Update overview.md
* update gptq.md
* Update gptq.md
* Update gptq.md
* Update gptq.md
* Update gptq.md
* Update gptq.md
* Update gptq.md
---------
Co-authored-by: Qubitium-ModelCloud <[email protected]>
* typo
* doc note for asymmetric quant
* typo with apple silicon(e)
* typo for marlin
* column name revert: review
* doc rocm support
* Update docs/source/en/quantization/gptq.md (Co-authored-by: Steven Liu <[email protected]>)
* Update docs/source/en/quantization/gptq.md (Co-authored-by: Steven Liu <[email protected]>)
* Update docs/source/en/quantization/gptq.md (Co-authored-by: Steven Liu <[email protected]>)
* Update docs/source/en/quantization/gptq.md (Co-authored-by: Steven Liu <[email protected]>)
* Update docs/source/en/quantization/overview.md (Co-authored-by: Steven Liu <[email protected]>)
* Update docs/source/en/quantization/overview.md (Co-authored-by: Steven Liu <[email protected]>)
---------
Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: LRL-ModelCloud <[email protected]>
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: Qubitium-ModelCloud <[email protected]>
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: LRL <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Mohamed Mekkouri <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
* Fix : Nemotron Processor in GGUF conversion (#35708)
* fixing nemotron processor
* make style
* Update docs/source/en/quantization/spqr.md (Co-authored-by: Arthur <[email protected]>)
* Add missing TOC to doc
---------
Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: jiqing-feng <[email protected]>
Co-authored-by: LRL-ModelCloud <[email protected]>
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: Qubitium-ModelCloud <[email protected]>
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: LRL <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Mohamed Mekkouri <[email protected]>
Co-authored-by: Arthur <[email protected]>
1 parent c5506f4 commit 845b0a2

File tree

16 files changed: +591 lines added, 0 lines removed

docker/transformers-quantization-latest-gpu/Dockerfile (+3)

```diff
@@ -53,6 +53,9 @@ RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
 # Add vptq for quantization testing
 RUN python3 -m pip install --no-cache-dir vptq
 
+# Add spqr for quantization testing
+RUN python3 -m pip install --no-cache-dir spqr_quant[gpu]
+
 # Add hqq for quantization testing
 RUN python3 -m pip install --no-cache-dir hqq
```

docs/source/en/_toctree.yml (+2)

```diff
@@ -166,6 +166,8 @@
   - local: quantization/aqlm
     title: AQLM
   - local: quantization/vptq
+    title: SpQR
+  - local: quantization/spqr
     title: VPTQ
   - local: quantization/quanto
     title: Quanto
```

docs/source/en/main_classes/quantization.md (+4)

```diff
@@ -81,6 +81,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
 
 [[autodoc]] BitNetConfig
 
+## SpQRConfig
+
+[[autodoc]] SpQRConfig
+
 ## FineGrainedFP8Config
 
 [[autodoc]] FineGrainedFP8Config
```
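For readers skimming the diff, here is a minimal sketch of how the newly documented `SpQRConfig` is meant to be constructed. The parameter names (`bits`, `beta1`, `beta2`, `shapes`, `modules_to_not_convert`) are taken from the integration code added later in this commit; the values shown are illustrative, not requirements.

```python
from transformers import SpQRConfig

# Illustrative only: for a prequantized checkpoint these values normally come
# from the model's config.json rather than being set by hand.
config = SpQRConfig(
    bits=3,
    beta1=16,
    beta2=16,
    modules_to_not_convert=["lm_head.weight"],
)
print(config.to_dict())
```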

docs/source/en/quantization/overview.md (+1)

```diff
@@ -61,6 +61,7 @@ Use the table below to help you decide which quantization method to use.
 | [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
 | [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
 | [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
+| [SpQR](./spqr.md) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
 | [FINEGRAINED_FP8](./finegrained_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
 
 <Tip>
```

docs/source/en/quantization/spqr.md (new file, +35)

````markdown
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# SpQR

[SpQR](https://github.com/Vahe1994/SpQR) quantization algorithm involves a 16x16 tiled bi-level group 3-bit quantization structure, with sparse outliers as detailed in [SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression](https://arxiv.org/abs/2306.03078).

To SpQR-quantize a model, refer to the [Vahe1994/SpQR](https://github.com/Vahe1994/SpQR) repository.

Load a pre-SpQR-quantized model in [`~PreTrainedModel.from_pretrained`].

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

quantized_model = AutoModelForCausalLM.from_pretrained(
    "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf",
    torch_dtype=torch.half,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf")
```
````
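The new doc page stops at loading. As a hedged follow-on (not part of the commit), the snippet below shows the usual next step of running generation with the quantized model; the prompt and generation settings are arbitrary.

```python
# Continues the doc snippet above; prompt and generation settings are arbitrary.
inputs = tokenizer("The recipe for a good open-source kernel is", return_tensors="pt").to(quantized_model.device)
outputs = quantized_model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```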

src/transformers/__init__.py (+2)

```diff
@@ -1029,6 +1029,7 @@
         "HiggsConfig",
         "HqqConfig",
         "QuantoConfig",
+        "SpQRConfig",
         "TorchAoConfig",
         "VptqConfig",
     ],
@@ -6202,6 +6203,7 @@
         HiggsConfig,
         HqqConfig,
         QuantoConfig,
+        SpQRConfig,
         TorchAoConfig,
         VptqConfig,
     )
```

src/transformers/integrations/__init__.py (+2)

```diff
@@ -106,6 +106,7 @@
     ],
     "peft": ["PeftAdapterMixin"],
     "quanto": ["replace_with_quanto_layers"],
+    "spqr": ["replace_with_spqr_linear"],
     "vptq": ["replace_with_vptq_linear"],
 }
 
@@ -210,6 +211,7 @@
     )
     from .peft import PeftAdapterMixin
     from .quanto import replace_with_quanto_layers
+    from .spqr import replace_with_spqr_linear
     from .vptq import replace_with_vptq_linear
 
     try:
```

src/transformers/integrations/spqr.py (new file, +122)

```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"SpQR (Sparse-Quantized Representation) integration file"

from ..utils import is_accelerate_available, is_spqr_available, is_torch_available


if is_torch_available():
    import torch.nn as nn


def replace_with_spqr_linear(
    model,
    quantization_config=None,
    modules_to_not_convert=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """
    Public method that recursively replaces the Linear layers of the given model with SpQR quantized layers.
    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
    conversion has been successful or not.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`SpQRConfig`):
            The quantization config object that contains the quantization parameters.
        modules_to_not_convert (`list[str]`, *optional*):
            A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`),
            the corresponding module will not be converted.
        current_key_name (`list`, *optional*):
            A list that contains the current key name. This is used for recursion and should not be passed by the user.
        has_been_replaced (`bool`, *optional*):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []

    if is_accelerate_available():
        from accelerate import init_empty_weights
    if is_spqr_available():
        from spqr_quant import QuantizedLinear

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear):
            # Check if the current key is not in the `modules_to_not_convert`
            if ".".join(current_key_name) + ".weight" not in modules_to_not_convert:
                with init_empty_weights():
                    tensor_name = ".".join(current_key_name)

                    shapes = quantization_config.shapes
                    shapes_keys = shapes.keys()

                    shapes_valid = (
                        f"{tensor_name}.dense_weights.shape" in shapes_keys
                        and f"{tensor_name}.row_offsets.shape" in shapes_keys
                        and f"{tensor_name}.col_vals.shape" in shapes_keys
                        and f"{tensor_name}.in_perm.shape" in shapes_keys
                    )

                    if not shapes_valid:
                        raise ValueError(
                            f"The SpQR quantization config does not contain the shape "
                            f"configuration for {tensor_name}. This indicates that the "
                            f"configuration is either invalid or corrupted."
                        )

                    dense_weights_shape = shapes[f"{tensor_name}.dense_weights.shape"]
                    row_offsets_shape = shapes[f"{tensor_name}.row_offsets.shape"]
                    col_vals_shape = shapes[f"{tensor_name}.col_vals.shape"]
                    in_perm_shape = shapes[f"{tensor_name}.in_perm.shape"]

                    in_features = module.in_features
                    out_features = module.out_features

                    model._modules[name] = QuantizedLinear.create_placehodler(
                        rows=out_features,
                        cols=in_features,
                        bits=quantization_config.bits,
                        beta1=quantization_config.beta1,
                        beta2=quantization_config.beta2,
                        dense_weights_shape=dense_weights_shape,
                        row_offsets_shape=row_offsets_shape,
                        col_vals_shape=col_vals_shape,
                        in_perm_shape=in_perm_shape,
                    )
                    has_been_replaced = True

                    # Store the module class in case we need to transpose the weight later
                    model._modules[name].source_cls = type(module)
                    # Force requires grad to False to avoid unexpected errors
                    model._modules[name].requires_grad_(False)
            else:
                pass
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_spqr_linear(
                module,
                quantization_config=quantization_config,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
```
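The shape validation above expects one `<module path>.<tensor>.shape` entry per SpQR tensor of every converted `nn.Linear`. A hypothetical excerpt of such a `shapes` mapping is sketched below; the module path and the numeric values are placeholders, not taken from a real checkpoint.

```python
# Hypothetical excerpt of the `shapes` mapping consumed by replace_with_spqr_linear:
# every converted linear layer needs these four "<name>.<tensor>.shape" keys.
shapes = {
    "model.layers.0.self_attn.q_proj.dense_weights.shape": 1536000,
    "model.layers.0.self_attn.q_proj.row_offsets.shape": 4097,
    "model.layers.0.self_attn.q_proj.col_vals.shape": 250000,
    "model.layers.0.self_attn.q_proj.in_perm.shape": 4096,
}
```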

src/transformers/quantizers/auto.py (+4)

```diff
@@ -31,6 +31,7 @@
     QuantizationConfigMixin,
     QuantizationMethod,
     QuantoConfig,
+    SpQRConfig,
     TorchAoConfig,
     VptqConfig,
 )
@@ -47,6 +48,7 @@
 from .quantizer_higgs import HiggsHfQuantizer
 from .quantizer_hqq import HqqHfQuantizer
 from .quantizer_quanto import QuantoHfQuantizer
+from .quantizer_spqr import SpQRHfQuantizer
 from .quantizer_torchao import TorchAoHfQuantizer
 from .quantizer_vptq import VptqHfQuantizer
 
@@ -66,6 +68,7 @@
     "torchao": TorchAoHfQuantizer,
     "bitnet": BitNetHfQuantizer,
     "vptq": VptqHfQuantizer,
+    "spqr": SpQRHfQuantizer,
     "fp8": FineGrainedFP8HfQuantizer,
 }
@@ -84,6 +87,7 @@
     "torchao": TorchAoConfig,
     "bitnet": BitNetConfig,
     "vptq": VptqConfig,
+    "spqr": SpQRConfig,
     "fp8": FineGrainedFP8Config,
 }
```
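A small sketch of what the two new mapping entries buy: a checkpoint whose `quantization_config` declares `quant_method: "spqr"` now resolves to `SpQRConfig` and `SpQRHfQuantizer` through the auto classes. The non-`quant_method` keys in the dict below are assumptions based on this PR's `SpQRConfig`, not copied from a real checkpoint.

```python
from transformers.quantizers.auto import AutoHfQuantizer, AutoQuantizationConfig

# Illustrative dict of the kind a SpQR checkpoint embeds in its config.json.
config = AutoQuantizationConfig.from_dict(
    {"quant_method": "spqr", "bits": 3, "beta1": 16, "beta2": 16, "shapes": {}}
)
quantizer = AutoHfQuantizer.from_config(config)
print(type(config).__name__, type(quantizer).__name__)  # SpQRConfig SpQRHfQuantizer
```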

src/transformers/quantizers/quantizer_spqr.py (new file, +83)

```python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from .base import HfQuantizer


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..integrations import replace_with_spqr_linear
from ..utils import is_accelerate_available, is_spqr_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class SpQRHfQuantizer(HfQuantizer):
    """
    Quantizer of the SpQR method. Enables the loading of prequantized models.
    """

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, *args, **kwargs):
        if not torch.cuda.is_available():
            raise RuntimeError("GPU is required to run SpQR quantized model.")

        if not is_accelerate_available():
            raise ImportError("Using `spqr` quantization requires Accelerate: `pip install accelerate`")

        if not is_spqr_available():
            raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant[gpu]`")

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            torch_dtype = torch.float16
            logger.info("Assuming SpQR inference on GPU and loading the model in `torch.float16`.")
        elif torch_dtype != torch.float16:
            raise ValueError(
                "You cannot use any type other than torch.float16 for SpQR. Please either leave it None or set it to"
                "torch.float16 explicitly."
            )
        return torch_dtype

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        **kwargs,
    ):
        replace_with_spqr_linear(
            model,
            quantization_config=self.quantization_config,
            modules_to_not_convert=self.quantization_config.modules_to_not_convert,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return False

    def is_serializable(self, safe_serialization=None):
        return True
```
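In practice the `update_torch_dtype` hook above means a SpQR checkpoint can only be loaded in `float16`. A hedged illustration follows, reusing the model id from the new doc page; it is not part of the commit itself.

```python
import torch
from transformers import AutoModelForCausalLM

# torch_dtype may be omitted (it then defaults to torch.float16) or set to
# torch.float16 explicitly; any other dtype makes update_torch_dtype raise.
model = AutoModelForCausalLM.from_pretrained(
    "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf",
    torch_dtype=torch.float16,
    device_map="auto",
)
```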

src/transformers/testing_utils.py (+8)

```diff
@@ -121,6 +121,7 @@
     is_seqio_available,
     is_soundfile_available,
     is_spacy_available,
+    is_spqr_available,
     is_sudachi_available,
     is_sudachi_projection_available,
     is_tensorflow_probability_available,
@@ -1191,6 +1192,13 @@ def require_vptq(test_case):
     return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case)
 
 
+def require_spqr(test_case):
+    """
+    Decorator marking a test that requires spqr
+    """
+    return unittest.skipUnless(is_spqr_available(), "test requires spqr")(test_case)
+
+
 def require_eetq(test_case):
     """
     Decorator marking a test that requires eetq
```
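A hypothetical test sketch (not part of this commit) showing how the new decorator is meant to be combined with the existing GPU and slow markers from `testing_utils`:

```python
import unittest

from transformers.testing_utils import require_spqr, require_torch_gpu, slow


@require_spqr
@require_torch_gpu
@slow
class SpQRSmokeTest(unittest.TestCase):
    # Hypothetical test: it is skipped unless spqr_quant and a CUDA GPU are available.
    def test_spqr_import(self):
        from spqr_quant import QuantizedLinear  # noqa: F401
```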

src/transformers/utils/__init__.py (+1)

```diff
@@ -193,6 +193,7 @@
     is_soundfile_available,
     is_spacy_available,
     is_speech_available,
+    is_spqr_available,
     is_sudachi_available,
     is_sudachi_projection_available,
     is_tensorflow_probability_available,
```

src/transformers/utils/import_utils.py (+5)

```diff
@@ -201,6 +201,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _blobfile_available = _is_package_available("blobfile")
 _liger_kernel_available = _is_package_available("liger_kernel")
 _triton_available = _is_package_available("triton")
+_spqr_available = _is_package_available("spqr_quant")
 
 _torch_version = "N/A"
 _torch_available = False
@@ -1213,6 +1214,10 @@ def is_speech_available():
     return _torchaudio_available
 
 
+def is_spqr_available():
+    return _spqr_available
+
+
 def is_phonemizer_available():
     return _phonemizer_available
```
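And the matching guard pattern for optional code paths; a minimal sketch, assuming only the public `is_spqr_available` helper added above:

```python
from transformers.utils import is_spqr_available

# Guard SpQR-specific imports on the new availability check.
if is_spqr_available():
    from spqr_quant import QuantizedLinear  # noqa: F401
else:
    print("spqr_quant is not installed; install it with `pip install spqr_quant[gpu]`.")
```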
