
Commit 71dfb52

Add GPTQ support
1 parent ad5f2fe commit 71dfb52

File tree

6 files changed: +170, -15 lines

vllm/config.py

Lines changed: 5 additions & 0 deletions

@@ -2,6 +2,7 @@
 
 import torch
 from transformers import PretrainedConfig
+from auto_gptq import BaseQuantizeConfig
 
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config
@@ -55,6 +56,10 @@ def __init__(
         self.seed = seed
 
         self.hf_config = get_config(model, trust_remote_code)
+        try:
+            self.quantize_config = BaseQuantizeConfig.from_pretrained(model)
+        except:
+            self.quantize_config = None
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_tokenizer_mode()
 
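
For reference, a minimal standalone sketch of the detection logic this hunk adds, assuming AutoGPTQ is installed; the model directory path is a placeholder. BaseQuantizeConfig.from_pretrained parses the checkpoint's quantize_config.json, and the resulting bits, group_size, and desc_act fields are consumed later by the model loader; any failure leaves quantize_config as None, which keeps unquantized models on the existing code path.

    from auto_gptq import BaseQuantizeConfig

    def detect_gptq_config(model_path: str):
        # Return the parsed quantize_config.json, or None for an unquantized model.
        try:
            return BaseQuantizeConfig.from_pretrained(model_path)
        except Exception:
            return None

    cfg = detect_gptq_config("/path/to/gptq-model")  # placeholder directory
    if cfg is not None:
        print(cfg.bits, cfg.group_size, cfg.desc_act)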

vllm/model_executor/model_loader.py

Lines changed: 24 additions & 1 deletion

@@ -3,11 +3,14 @@
 
 import torch
 import torch.nn as nn
+from accelerate import init_on_device
 from transformers import PretrainedConfig
+from auto_gptq.modeling._utils import autogptq_post_init
 
 from vllm.config import ModelConfig
 from vllm.model_executor.models import *  # pylint: disable=wildcard-import
 from vllm.model_executor.weight_utils import initialize_dummy_weights
+from vllm.model_executor.quantize import make_quant, find_layers
 
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
@@ -46,7 +49,25 @@ def get_model(model_config: ModelConfig) -> nn.Module:
 
     # Create a model instance.
     # The weights will be initialized as empty tensors.
-    model = model_class(model_config.hf_config)
+    if model_config.quantize_config:
+        with init_on_device(device=torch.device("cpu")):
+            model = model_class(model_config.hf_config)
+        layers = find_layers(model)
+        ignore_layers = [model_class.lm_head_name] + model_class.outside_layer_modules
+        for name in list(layers.keys()):
+            if any([name.startswith(ignore_layer) for ignore_layer in ignore_layers]):
+                del layers[name]
+
+        make_quant(
+            model,
+            layers,
+            model_config.quantize_config.bits,
+            model_config.quantize_config.group_size,
+            desc_act=model_config.quantize_config.desc_act,
+        )
+        model.quantize_config = model_config.quantize_config
+    else:
+        model = model_class(model_config.hf_config)
     if model_config.use_dummy_weights:
         model = model.cuda()
         # NOTE(woosuk): For accurate performance evaluation, we assign
@@ -57,4 +78,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
         model.load_weights(model_config.model, model_config.download_dir,
                            model_config.use_np_weights)
         model = model.cuda()
+    if model_config.quantize_config:
+        model = autogptq_post_init(model, use_act_order=model_config.quantize_config.desc_act)
     return model.eval()
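
The filtering above drops the lm_head, the embedding, and the final norm from the set of layers handed to make_quant, since GPTQ checkpoints typically leave those modules unquantized. A small illustrative sketch of that filter follows; the layer-name dict is made up for the example and is not output from the commit.

    # Hypothetical output of find_layers() on a one-block Llama-style model.
    layers = {
        "model.embed_tokens": "...",
        "model.layers.0.self_attn.qkv_proj": "...",
        "model.layers.0.mlp.gate_up_proj": "...",
        "model.norm": "...",
        "lm_head": "...",
    }
    ignore_layers = ["lm_head"] + ["model.embed_tokens", "model.norm"]

    for name in list(layers.keys()):
        if any(name.startswith(ignore_layer) for ignore_layer in ignore_layers):
            del layers[name]

    # Only the transformer-block projections remain to be replaced by QuantLinear:
    # ['model.layers.0.mlp.gate_up_proj', 'model.layers.0.self_attn.qkv_proj']
    print(sorted(layers))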

vllm/model_executor/models/llama.py

Lines changed: 37 additions & 11 deletions

@@ -229,6 +229,8 @@ def forward(
 
 
 class LlamaForCausalLM(nn.Module):
+    lm_head_name = "lm_head"
+    outside_layer_modules = ["model.embed_tokens", "model.norm"]
 
     def __init__(self, config):
         super().__init__()
@@ -301,11 +303,23 @@ def load_weights(self,
                 if weight_name not in name:
                     continue
                 param = state_dict[name.replace(weight_name, "qkv_proj")]
-
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
-                param_slice = param.data[offset:offset + shard_size]
+                if "g_idx" in name:
+                    param.data.copy_(loaded_weight)
+                    is_attention_weight = True
+                    continue
+                if any(key in name for key in ('qweight', 'qzeros', 'scales')):
+                    if 'qzeros' in name:
+                        shard_size = shard_size // 32 * self.quantize_config.bits
+                        offset = offset // 32 * self.quantize_config.bits
+                    loaded_weight = loaded_weight[:,
+                        shard_size * tensor_model_parallel_rank:shard_size *
+                        (tensor_model_parallel_rank + 1)]
+                    param_slice = param.data[:, offset:offset + shard_size]
+                else:
+                    loaded_weight = loaded_weight[
+                        shard_size * tensor_model_parallel_rank:shard_size *
+                        (tensor_model_parallel_rank + 1)]
+                    param_slice = param.data[offset:offset + shard_size]
                 assert param_slice.shape == loaded_weight.shape
 
                 param_slice.copy_(loaded_weight)
@@ -319,12 +333,24 @@ def load_weights(self,
                 if weight_name not in name:
                     continue
                 param = state_dict[name.replace(weight_name, "gate_up_proj")]
-                shard_size = param.shape[0] // 2
-                loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank:shard_size *
-                    (tensor_model_parallel_rank + 1)]
-                param_slice = param.data[shard_size * stride_id:shard_size *
-                                         (stride_id + 1)]
+                if "g_idx" in name:
+                    param.data.copy_(loaded_weight)
+                    is_gate_up_weight = True
+                    continue
+                if any(key in name for key in ('qweight', 'qzeros', 'scales')):
+                    shard_size = param.shape[1] // 2
+                    loaded_weight = loaded_weight[:,
+                        shard_size * tensor_model_parallel_rank:shard_size *
+                        (tensor_model_parallel_rank + 1)]
+                    param_slice = param.data[:, shard_size * stride_id:shard_size *
+                                             (stride_id + 1)]
+                else:
+                    shard_size = param.shape[0] // 2
+                    loaded_weight = loaded_weight[
+                        shard_size * tensor_model_parallel_rank:shard_size *
+                        (tensor_model_parallel_rank + 1)]
+                    param_slice = param.data[shard_size * stride_id:shard_size *
+                                             (stride_id + 1)]
                 assert param_slice.shape == loaded_weight.shape
                 param_slice.copy_(loaded_weight)
                 is_gate_up_weight = True
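
The qzeros rescaling in the attention hunk reflects how AutoGPTQ packs zero-points: each int32 column of qzeros holds 32 // bits values, so a shard size and offset measured in unpacked output columns must be scaled by bits / 32 before slicing the packed tensor. The qweight, qzeros, and scales tensors are sliced along their second dimension because QuantLinear lays out output features along that dimension. A small worked example of the arithmetic; the sizes are illustrative.

    bits = 4                           # 4-bit GPTQ: 32 // 4 = 8 zero-points per int32
    shard_size, offset = 4096, 8192    # illustrative sizes in unpacked output columns

    packed_shard = shard_size // 32 * bits   # 4096 // 32 * 4 = 512 packed columns
    packed_offset = offset // 32 * bits      # 8192 // 32 * 4 = 1024 packed columns
    assert (packed_shard, packed_offset) == (512, 1024)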

vllm/model_executor/quantize.py

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
+# Adapted from
+# https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_utils.py
+"""Utilities for quantizing models."""
+from typing import List, Dict
+
+import torch.nn as nn
+import transformers
+
+from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
+
+from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
+    ColumnParallelLinear,
+    RowParallelLinear,
+)
+
+
+def find_layers(
+    module: nn.Module,
+    layers: List[nn.Module] = None,
+    name: str = ''
+) -> Dict[str, nn.Module]:
+    if not layers:
+        layers = [transformers.pytorch_utils.Conv1D, nn.Conv2d, nn.Linear,
+                  ColumnParallelLinear, RowParallelLinear]
+    for layer in layers:
+        if isinstance(module, layer):
+            return {name: module}
+    res = {}
+    for name1, child in module.named_children():
+        res.update(find_layers(child, layers=layers,
+                               name=name + '.' + name1 if name != '' else name1))
+    return res
+
+
+def make_quant(
+    module: nn.Module,
+    names: List[str],
+    bits: int,
+    group_size: int,
+    name: str = '',
+    use_triton: bool = False,
+    disable_exllama: bool = False,
+    use_cuda_fp16: bool = True,
+    desc_act: bool = False,
+    trainable: bool = False
+) -> None:
+    QuantLinear = dynamically_import_QuantLinear(
+        use_triton=use_triton, desc_act=desc_act, group_size=group_size,
+        bits=bits, disable_exllama=disable_exllama)
+
+    class QuantLinearWrapper(QuantLinear):
+        def forward(self, *args, **kwargs):
+            return super().forward(*args, **kwargs), None
+
+    if isinstance(module, QuantLinear):
+        return
+    for attr in dir(module):
+        tmp = getattr(module, attr)
+        name1 = name + '.' + attr if name != '' else attr
+        if name1 in names:
+            delattr(module, attr)
+            if isinstance(tmp, nn.Linear):
+                in_features = tmp.in_features
+                out_features = tmp.out_features
+            elif isinstance(tmp, nn.Conv2d):
+                in_features = tmp.in_channels
+                out_features = tmp.out_channels
+            elif isinstance(tmp, transformers.pytorch_utils.Conv1D):
+                in_features = tmp.weight.shape[0]
+                out_features = tmp.weight.shape[1]
+            elif isinstance(tmp, ColumnParallelLinear) or isinstance(tmp, RowParallelLinear):
+                in_features = tmp.input_size
+                out_features = tmp.output_size
+            if (not desc_act or group_size == -1) and not use_triton:
+                new_layer = QuantLinearWrapper(
+                    bits, group_size, in_features, out_features, True,
+                    use_cuda_fp16=use_cuda_fp16, trainable=trainable)
+            else:
+                new_layer = QuantLinearWrapper(
+                    bits, group_size, in_features, out_features, True, trainable=trainable)
+            setattr(module, attr, new_layer)
+    for name1, child in module.named_children():
+        make_quant(
+            child,
+            names,
+            bits,
+            group_size,
+            name + '.' + name1 if name != '' else name1,
+            use_triton=use_triton,
+            use_cuda_fp16=use_cuda_fp16,
+            desc_act=desc_act,
+            trainable=trainable,
+            disable_exllama=disable_exllama,
+        )
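
Two details worth noting in this new file: find_layers walks the module tree and returns only Linear-like leaves keyed by their dotted path, and QuantLinearWrapper returns an (output, None) tuple so a swapped-in quantized layer matches the (output, bias) return convention of vLLM's ColumnParallelLinear and RowParallelLinear at existing call sites. A small usage sketch of find_layers, assuming vLLM and AutoGPTQ are importable; the toy module is made up.

    import torch.nn as nn
    from vllm.model_executor.quantize import find_layers

    class ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.up_proj = nn.Linear(16, 64)
            self.down_proj = nn.Linear(64, 16)
            self.norm = nn.LayerNorm(16)  # not a Linear-like layer, so not returned

    print(find_layers(ToyBlock()))
    # {'up_proj': Linear(in_features=16, out_features=64, ...),
    #  'down_proj': Linear(in_features=64, out_features=16, ...)}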

vllm/model_executor/weight_utils.py

Lines changed: 14 additions & 2 deletions

@@ -9,6 +9,7 @@
 import numpy as np
 import torch
 from tqdm.auto import tqdm
+from safetensors.torch import safe_open
 
 
 class Disabledtqdm(tqdm):
@@ -33,7 +34,7 @@ def hf_model_weights_iterator(
     if not is_local:
         with lock:
             hf_folder = snapshot_download(model_name_or_path,
-                                          allow_patterns="*.bin",
+                                          allow_patterns=["*.bin", "*.safetensors"],
                                           cache_dir=cache_dir,
                                           tqdm_class=Disabledtqdm)
     else:
@@ -43,8 +44,19 @@
         x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
         if not x.endswith("training_args.bin")
     ]
+    safetensor_files = [
+        x for x in glob.glob(os.path.join(hf_folder, "*.safetensors"))
+    ]
 
-    if use_np_cache:
+    # prioritize safetensor files
+    if safetensor_files:
+        for st_file in safetensor_files:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():
+                    param = f.get_tensor(name)
+                    yield name, param
+        torch.cuda.empty_cache()
+    elif use_np_cache:
         # Convert the model weights from torch tensors to numpy arrays for
         # faster loading.
         np_folder = os.path.join(hf_folder, "np")
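
For reference, a minimal standalone sketch of the safetensors iteration pattern used by the new branch; the shard path is a placeholder.

    from safetensors.torch import safe_open

    def iter_safetensors(st_file: str):
        # Lazily yield (name, tensor) pairs from one .safetensors shard.
        with safe_open(st_file, framework="pt") as f:
            for name in f.keys():
                yield name, f.get_tensor(name)

    for name, tensor in iter_safetensors("/path/to/model.safetensors"):
        print(name, tuple(tensor.shape))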

vllm/transformers_utils/config.py

Lines changed: 1 addition & 1 deletion

@@ -30,4 +30,4 @@ def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model)
-    return config
+    return config
