
Commit a441a3f

Authored by justheuristic, dvmazur, and galqiwi
merge PV-Tuning into AQLM main (#110)
PV-tuning

Co-authored-by: Denis Mazur <[email protected]>
Co-authored-by: Vladimir Malinovskii <[email protected]>
1 parent 559a366 commit a441a3f

16 files changed: +3,347 −734 lines

aq_engine.py

+7 −5
@@ -47,7 +47,6 @@ def quantize(self, *, args: Namespace, verbose: bool = True) -> QuantizedWeight:
        assert isinstance(args.devices, (list, tuple)) and len(args.devices) >= 1, f"Found devices = {args.devices}"
        assert args.devices[0] == self.device, (args.devices[0], self.XTX.device)
        self.quantized_weight = QuantizedWeight(
-           XTX=self.XTX.to(device=self.device, dtype=torch.float32),
            reference_weight=self.layer.weight.detach().to(device=self.device, dtype=torch.float32),
            out_group_size=args.out_group_size,
            in_group_size=args.in_group_size,
@@ -165,7 +164,7 @@ def _replace_and_beam_search(self, params_to_replace: nn.ParameterDict, selectio
        )
        reference_weight = self.layer.weight.detach()[out_channel_selection].to(dtype)
        return self.quantized_weight.beam_search_update_codes_(
-           self.XTX.to(dtype), reference_weight, selection=selection, **kwargs
+           XTX=self.XTX.to(dtype), reference_weight=reference_weight, selection=selection, **kwargs
        ).clone()

    @torch.no_grad()
@@ -177,12 +176,15 @@ def beam_search_update_codes_(
        seed: Optional[int] = None,
        **kwargs,
    ):
-       """Update self.quantized_weight.codes in-place via beam search"""
+       """Update quantized_weight codes in-place via beam search"""
        if len(devices) == 1:  # single device
            assert replicas is None
            dtype = self.quantized_weight.codebooks.dtype
            self.quantized_weight.beam_search_update_codes_(
-               self.XTX.to(dtype), self.layer.weight.detach().to(dtype), dim_rng=random.Random(seed), **kwargs
+               XTX=self.XTX.to(dtype),
+               reference_weight=self.layer.weight.detach().to(dtype),
+               dim_rng=random.Random(seed),
+               **kwargs,
            )
        else:
            assert replicas[0] is self
@@ -203,7 +205,7 @@ def beam_search_update_codes_(
            )
            # gather all code parts and assign them to each replica
            for device, replica in zip(devices, replicas):
-               replica.quantized_weight.codes[...] = Gather.apply(device, 0, *new_code_parts_by_replica)
+               replica.quantized_weight.set_codes(Gather.apply(device, 0, *new_code_parts_by_replica))


def replace_parameter_(module: nn.Module, name: str, new_value: torch.Tensor):
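
The pattern in these hunks: beam-search call sites now pass XTX and reference_weight as keyword arguments, and multi-GPU replicas stop assigning into quantized_weight.codes[...] directly and call set_codes(...) instead, presumably so the update keeps working when codes live in a separate codes_storage (handled by the converter below). A minimal sketch of what such a setter could look like, assuming codes is an integer, non-trainable nn.Parameter; this is an illustration, not the repository's implementation:

import torch
from torch import nn


class QuantizedWeightSketch(nn.Module):
    """Illustration only: route code updates through a method instead of raw item assignment."""

    def __init__(self, codes: torch.Tensor):
        super().__init__()
        # integer codes stored as a non-trainable Parameter
        self.codes = nn.Parameter(codes, requires_grad=False)

    @torch.no_grad()
    def set_codes(self, new_codes: torch.Tensor) -> None:
        # in-place copy keeps the registered Parameter (and its device/dtype) intact
        self.codes.copy_(new_codes.to(device=self.codes.device, dtype=self.codes.dtype))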

convert_legacy_model_format.py

+214
@@ -0,0 +1,214 @@
"""
This abomination converts between one of several quantized model formats to the same format as returned by main.py .
This code exists because we failed to produce a single data format for quantized model.
We should eventually switch to saving all models in the same data format. Once we do, this file should be deleted.
"""
import argparse
import os
import warnings
from copy import deepcopy

import torch
import transformers.models
from torch import nn

from src.aq import QuantizedLinear, QuantizedWeight
from src.modelutils import get_model, save_quantized_model
from src.utils import is_signed


def load_quantized_model_with_old_pickle(base_model_name: str, quantized_model_name: str, **kwargs):
    """Hacky way to allow compatibility between old *pickled* layers and new transformers"""
    # because patching it for the fourth time is better than writing a proper saver once >.<
    import transformers.activations

    if not hasattr(transformers.activations, "SiLUActivation"):
        transformers.activations.SiLUActivation = deepcopy(torch.nn.SiLU)
        transformers.activations.SiLUActivation.inplace = False
    # https://github.com/huggingface/transformers/issues/28496
    if not hasattr(transformers.models.llama.modeling_llama.LlamaAttention, "attention_dropout"):
        transformers.models.llama.modeling_llama.LlamaAttention.attention_dropout = 0
    quantized_model = get_model(base_model_name, None, **kwargs)
    quantized_model_src = get_model(base_model_name, quantized_model_name, **kwargs)
    for module in quantized_model_src.modules():
        if isinstance(module, QuantizedWeight) and not hasattr(module, "codes_storage"):
            module.codes_storage = None  # backwards compatibility with older pickled snapshots

    lut = {}
    for name, module in quantized_model_src.named_modules():
        for child_name, child_module in module.named_children():
            if isinstance(child_module, QuantizedWeight):
                lut[name + "." + child_name] = child_module
    print(f"found {len(lut)} quantized weight matrices")
    for name, module in quantized_model.named_modules():
        for child_name, child_module in module.named_children():
            if name + "." + child_name + ".quantized_weight" in lut:
                quantized_weight = lut.pop(name + "." + child_name + ".quantized_weight")
                assert isinstance(child_module, nn.Linear)
                setattr(module, child_name, QuantizedLinear(quantized_weight, bias=child_module.bias))
    assert not lut, list(lut.keys())
    quantized_model.load_state_dict(quantized_model_src.state_dict())
    warnings.warn("You should be ashamed of yourself.")
    return quantized_model


import functools


def rsetattr(obj, attr, val):
    pre, _, post = attr.rpartition(".")
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)


def rgetattr(obj, attr, *args):
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)

    return functools.reduce(_getattr, [obj] + attr.split("."))


def load_quantized_model_from_fdsp_checkpoint(base_model_name: str, fsdp_checkpoint_path: str, **kwargs):
    original_model = get_model(base_model_name, None, **kwargs)

    state_filenames = os.listdir(fsdp_checkpoint_path)

    non_quant_fname = "non_quantized_state_dict.pth"
    non_quant_path = os.path.join(fsdp_checkpoint_path, non_quant_fname)
    non_quant_states = torch.load(non_quant_path)

    incomp_keys = original_model.load_state_dict(non_quant_states, strict=False)
    assert not incomp_keys.unexpected_keys

    missing_keys = list()
    for module_name, module in original_model.named_modules():
        if not isinstance(module, nn.Linear):
            continue

        assert not module.bias
        state_fname = f"{module_name}.weight.pth"

        if state_fname not in state_filenames:
            missing_keys.append(module_name)
            continue

        state_path = os.path.join(fsdp_checkpoint_path, state_fname)
        quantized_weight = torch.load(state_path, map_location="cpu")
        quantized_linear = QuantizedLinear(quantized_weight, bias=None)
        rsetattr(original_model, module_name, quantized_linear)

    return original_model


def main():
    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument(
        "--base_model",
        type=str,
        required=True,
        help="path or name of the teacher model",
    )
    parser.add_argument(
        "--quantized_model",
        type=str,
        required=True,
        help="path to quantized model",
    )
    parser.add_argument(
        "--load_dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "float32", "bfloat16"],
        help="dtype to load the model in",
    )
    parser.add_argument(
        "--code_dtype",
        type=str,
        default=None,
        help="if specified, cast quantized layers' codes to this dtype; default = keep loaded dtype",
    )
    parser.add_argument(
        "--p_finetuned_state_dict",
        type=str,
        default=None,
        help="path to quantized model state dict saved by the old FSDP finetuning code",
    )
    parser.add_argument(
        "--pv_fsdp_dir",
        type=str,
        default=None,
        help="path to quantized model state dict saved by the old FSDP finetuning code",
    )
    parser.add_argument(
        "--monkeypatch_old_pickle",
        action="store_true",
        help="If set, load quantized_model in a hacky way that allows pickled models with older transformers/torch.",
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default=None,
        help="Attention implementation for both teacher and student models: eager, sdpa, or flash_attention_2",
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Whether to trust remote code when loading base model.",
    )
    parser.add_argument("--save", type=str, required=True, help="Save the converted quantized model here")

    args = parser.parse_args()
    assert args.p_finetuned_state_dict or args.pv_fsdp_dir, "either one of those must be specified"
    print(f"{args.p_finetuned_state_dict=}, {args.pv_fsdp_dir=}")
    assert (args.p_finetuned_state_dict is not None) != (args.pv_fsdp_dir is not None)

    args.load_dtype = getattr(torch, args.load_dtype) if args.load_dtype != "auto" else "auto"
    args.code_dtype = getattr(torch, args.code_dtype) if args.code_dtype is not None else None

    if not args.monkeypatch_old_pickle:
        quantized_model = get_model(
            args.base_model,
            args.quantized_model,
            dtype=args.load_dtype,
            trust_remote_code=args.trust_remote_code,
            attn_implementation=args.attn_implementation,
        )
    elif args.p_finetuned_state_dict:
        quantized_model = load_quantized_model_with_old_pickle(
            args.base_model,
            args.quantized_model,
            dtype=args.load_dtype,
            trust_remote_code=args.trust_remote_code,
            attn_implementation=args.attn_implementation,
        )
    elif args.pv_fsdp_dir:
        quantized_model = load_quantized_model_from_fdsp_checkpoint(
            args.base_model,
            args.pv_fsdp_dir,
            dtype=args.load_dtype,
            trust_remote_code=args.trust_remote_code,
        )

    for module in quantized_model.modules():
        if isinstance(module, QuantizedWeight):
            if not hasattr(module, "codes_storage"):
                module.codes_storage = None
            if module.codes is None:
                module.unwrap_codes_()
            assert module.codes is not None
            if args.code_dtype is not None:
                assert module.nbits_per_codebook <= torch.iinfo(args.code_dtype).bits - is_signed(args.code_dtype)
                module.codes = nn.Parameter(module.codes.to(args.code_dtype), requires_grad=module.codes.requires_grad)

    if args.p_finetuned_state_dict is not None:
        state_dict = torch.load(args.p_finetuned_state_dict, map_location="cpu")
        state_dict = {k: v for k, v in state_dict.items() if not k.endswith(".codes_storage.data")}
        status = quantized_model.load_state_dict(state_dict, strict=False)
        assert all(key.endswith("codes") for key in status.missing_keys)
        assert not status.unexpected_keys
        del state_dict, status  # note: in this case, it is okay not to load codes since P step does not change them

    save_quantized_model(quantized_model, args.save)


if __name__ == "__main__":
    main()
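
For reference, a typical invocation of this converter might look like the following; the model name and paths are placeholders, not values from the commit:

python convert_legacy_model_format.py \
    --base_model meta-llama/Llama-2-7b-hf \
    --quantized_model ./path/to/aqlm_quantized_model \
    --p_finetuned_state_dict ./path/to/p_finetuned_state_dict.pt \
    --load_dtype auto \
    --save ./converted_model

Exactly one of --p_finetuned_state_dict or --pv_fsdp_dir must be given, as enforced by the assertions at the top of main().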
