load_gptq_model.py
import sys
from pathlib import Path

import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM

# Make the GPTQ-for-LLaMa repository importable (expected as a checkout
# under repositories/).
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
from modelutils import find_layers
from quant import make_quant


def load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False,
               exclude_layers=['lm_head'], kernel_switch_threshold=128):
    config = AutoConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    # Skip weight initialization: the weights are overwritten by the
    # quantized checkpoint below, so initializing them would waste time.
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    # Build the model skeleton in fp16 without loading any weights.
    torch.set_default_dtype(torch.half)
    transformers.modeling_utils._init_weights = False
    model = AutoModelForCausalLM.from_config(config)
    torch.set_default_dtype(torch.float)
    model = model.eval()

    # Swap every linear layer, except the excluded ones (e.g. lm_head),
    # for its quantized equivalent.
    layers = find_layers(model)
    for name in exclude_layers:
        if name in layers:
            del layers[name]
    make_quant(model, layers, wbits, groupsize, faster=faster_kernel,
               kernel_switch_threshold=kernel_switch_threshold)
    del layers

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint))
    else:
        model.load_state_dict(torch.load(checkpoint))
    model.seqlen = 2048
    print('Done.')

    return model


def load_quantized(model_name, wbits=4, groupsize=128, threshold=128):
    model_name = model_name.replace('/', '_')
    path_to_model = Path(f'./models/{model_name}')

    # Expect exactly one quantized checkpoint in the model directory,
    # in either .pt or .safetensors format.
    found_pts = list(path_to_model.glob("*.pt"))
    found_safetensors = list(path_to_model.glob("*.safetensors"))
    pt_path = None

    if len(found_pts) == 1:
        pt_path = found_pts[0]
    elif len(found_safetensors) == 1:
        pt_path = found_safetensors[0]

    if not pt_path:
        print("Could not find the quantized model in .pt or .safetensors format, exiting...")
        sys.exit()

    model = load_quant(str(path_to_model), str(pt_path), wbits, groupsize,
                       kernel_switch_threshold=threshold)
    return model