
Commit a6cb519

Vahe1994, BlackSamorez, Godofnothing, efrantar, and dalistarh
committed
Initial commit.
Co-authored-by: Andrei Panferov <[email protected]>
Co-authored-by: Denis Kuznedelev <[email protected]>
Co-authored-by: Elias Frantar <[email protected]>
Co-authored-by: Dan Alistarh <[email protected]>
1 parent 25fe98f commit a6cb519

File tree: 861 files changed, +21333 -1 lines changed


.github/workflows/check-style.yaml

Lines changed: 26 additions & 0 deletions
```yaml
name: Check style

on:
  push:
    branches: [ main ]
  pull_request:

jobs:
  black:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: psf/black@stable
        with:
          options: "--check --diff"
          version: "22.3.0"
  isort:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          python-version: 3.8
      - uses: isort/isort-action@master
        with:
          isortVersion: "5.10.1"
```

README.md

Lines changed: 126 additions & 1 deletion
# AQLM

Official PyTorch repository for [Extreme Compression of Large Language Models via Additive Quantization](https://arxiv.org/pdf/2401.06118.pdf).

## Installation

### Packages

Install packages from `requirements.txt`:
```bash
pip install -r requirements.txt
```

### Loading / caching datasets and tokenizer

The scripts require downloading and caching the relevant tokenizer and datasets locally.
They are saved in the default Hugging Face Datasets directory unless an alternative location is provided via environment variables.
See the [relevant Datasets documentation section](https://huggingface.co/docs/datasets/main/en/cache#cache-directory).
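For example, the cache can be pre-populated in a non-default location via the standard `HF_DATASETS_CACHE` environment variable. The following is a minimal sketch, not part of the repository; the directory is a placeholder, and `HF_DATASETS_CACHE` must point to the same path when the quantization scripts are later launched so they reuse this cache.

```python
# Pre-download a dataset into a custom cache directory (placeholder path).
import os

os.environ["HF_DATASETS_CACHE"] = "/mnt/storage/hf_datasets_cache"  # placeholder; set before importing `datasets`

from datasets import load_dataset

# The RedPajama sample mentioned in the Data section below; any later run that
# exports HF_DATASETS_CACHE to the same directory will pick up this cached copy.
dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", split="train")
print(dataset)
```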
## Models

So far, this repository is expected to work with models of the `LLaMA` family.

## Data

For quantization with AQLM it is recommended to use a subset of the data the model was trained on. For example, to quantize `LLaMA 2` models we recommend using a subset of [RedPajama](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample). Subsets of RedPajama for 2048 and 4096 context lengths are stored in the `data` directory:
* `red_pajama_n=1024_2048_context_length.pth`
* `red_pajama_n=1024_4096_context_length.pth`

**Note:** these subsets are already processed with the corresponding model tokenizer. Using them with a different model will lead to unexpected behavior.
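The bundled calibration files can be inspected with `torch.load` before use. This is a sketch only: the exact container format is not documented here, so the script simply reports what it finds.

```python
# Quick sanity check of a bundled calibration file; the layout of the loaded
# object is an assumption, so print whatever structure is actually stored.
import torch

calib = torch.load("data/red_pajama_n=1024_2048_context_length.pth")
print(type(calib))
if isinstance(calib, (list, tuple)):
    print(len(calib), "items; first item:", getattr(calib[0], "shape", calib[0]))
elif torch.is_tensor(calib):
    print("tensor of shape", tuple(calib.shape))
```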
### W&B logging

For convenience, one can optionally log the data to the `Weights and Biases` service (wandb).
Run `pip install wandb` to enable W&B logging.
Specify the `$WANDB_ENTITY`, `$WANDB_PROJECT`, and `$WANDB_NAME` environment variables prior to running experiments, and pass the `--wandb` argument to enable logging.

# Launching

### GPU and RAM requirements

This code was developed and tested using several A100 GPUs with 80GB of GPU RAM each.
The `--offload_activations` option reduces VRAM usage.
For `Language Model Evaluation Harness` evaluation one needs enough memory to load the whole model on one or several devices, plus activation tensors.

### Model downloading

The code requires the LLaMA model to be downloaded in Hugging Face format and saved locally. The scripts below assume that the `$TRANSFORMERS_CACHE` variable points to the Hugging Face Transformers cache folder.
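One way to obtain such a local copy is sketched below; this is not from the repository. The model id `meta-llama/Llama-2-7b-hf` is illustrative, and gated weights require prior access approval plus `huggingface-cli login`.

```python
# Download a LLaMA checkpoint in Hugging Face format and save it to a local folder,
# which can then be passed to main.py as MODEL_PATH. Model id and directory are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"   # illustrative, gated checkpoint
local_dir = "./llama-2-7b-hf"           # placeholder local directory

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")

tokenizer.save_pretrained(local_dir)
model.save_pretrained(local_dir)
```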
### Perplexity benchmarks

This script compresses the model and then tests its performance in terms of perplexity using the WikiText2, C4, and Penn Treebank datasets.

The command to launch the script should look like this:

```
export MODEL_PATH=<PATH_TO_MODEL_DIR>
export DATASET=<INSERT DATASET NAME OR PATH TO CUSTOM DATA>

python main.py $MODEL_PATH $DATASET \
 --num_codebooks=2 \
 --relative_mse_tolerance=0.01 \
 --go_relative_mse_tolerance=0.001 \
 --nsamples=1024 \
 --nbits_per_codebook=15 \
 --in_group_size=8 \
 --scale_nbits=0 \
 --local_batch_size=4 \
 --save="save_path" \
 --batch_size=32 \
 --wandb
```

Note the launch arguments:
- `<PATH_TO_MODEL_DIR>` - path to the model folder, which contains `config.json`
- `DATASET` - one of `[c4, ptb, wikitext2, pajama, refinedweb, none]` (the name of the dataset to use for compression), or a path to an alternative preprocessed and tokenized dataset
- `--num_codebooks` - number of codebooks per layer (see the bits-per-weight sketch after this list)
- `--batch_size` - size of the sequence batch for fine-tuning the layer (GO), globally across all GPUs
- `--local_batch_size` - per-device and per-forward-pass batch size used to accumulate the global `--batch_size`
- `--nsamples` - number of calibration data samples; if None, all calibration data is used
- `--relative_mse_tolerance` - stop training when `(current_epoch_mse / previous_epoch_mse) > (1 - relative_mse_tolerance)`
- `--in_group_size` - how many input features are quantized together
- `--nbits_per_codebook` - codebook size; each codebook will contain `2 ** nbits_per_codebook` vectors
- `--scale_nbits` - number of bits dedicated to the learnable group-wise scale; 0 will use row-wise scales
- `--offload_activations` - moves activations to RAM when not used; reduces VRAM usage while slowing work by ~10%
- `--save` / `--load` - path to save / load the quantized model
- `--wandb` - log to wandb

Run `python main.py --help` for more details on command-line arguments, including compression parameters.
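As a rough guide (an estimate, not stated in the repository), the code storage alone costs about `num_codebooks * nbits_per_codebook / (in_group_size * out_group_size)` bits per weight; codebook tables and scales add a small overhead on top. A minimal sketch of that arithmetic:

```python
# Back-of-the-envelope bits-per-weight for code storage only
# (ignores codebook tables and scales, so the true footprint is slightly larger).
def approx_code_bits_per_weight(num_codebooks: int, nbits_per_codebook: int,
                                in_group_size: int, out_group_size: int = 1) -> float:
    weights_per_group = in_group_size * out_group_size
    return num_codebooks * nbits_per_codebook / weights_per_group

# The example launch above: 2 codebooks x 15 bits over groups of 8 input features.
print(approx_code_bits_per_weight(2, 15, 8))  # -> 3.75 bits per weight
```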
### LM Evaluation Harness benchmark

To perform zero-shot evaluation, we use the [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) framework with slight modifications. This repository contains a copy of the LM Evaluation Harness repo from early 2023 in the `lm-eval-harness` folder.

#### Installation

Before running the code, make sure that you have all the requirements and dependencies of `lm-eval-harness` installed. To install them, run:
```
pip install -r lm-evaluation-harness/requirements.txt
```

#### Execution

The main script launching the evaluation procedure is `lmeval.py`.

```
export MODEL_PATH=<INSERT PATH_TO_MODEL_DIR>
export DATASET=<INSERT DATASET NAME OR PATH TO CUSTOM DATA>

python lmeval.py \
 --model hf-causal \
 --model_args pretrained=$MODEL_PATH,dtype=float16,use_accelerate=True \
 --load $QUANTIZED_MODEL \
 --tasks winogrande,piqa,hellaswag,arc_easy,arc_challenge \
 --batch_size 1
```
## Contributing

We use `black` and `isort` for all pull requests. Before committing your code, run `black . && isort .`.

## Citation
```
@misc{egiazarian2024extreme,
      title={Extreme Compression of Large Language Models via Additive Quantization},
      author={Vage Egiazarian and Andrei Panferov and Denis Kuznedelev and Elias Frantar and Artem Babenko and Dan Alistarh},
      year={2024},
      eprint={2401.06118},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```

__init__.py

Whitespace-only changes.

aq_engine.py

Lines changed: 216 additions & 0 deletions
```python
from __future__ import annotations

import math
import random
from argparse import Namespace
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
from torch.nn.parallel.scatter_gather import Gather

from src.aq import QuantizedWeight
from src.utils import ellipsis


class AQEngine(nn.Module):
    """A wrapper class that runs AQ training for a single linear layer. All the important math is in aq.py"""

    def __init__(self, layer: nn.Linear, accumultor_dtype: torch.dtype = torch.float64):
        super().__init__()
        self.layer = layer
        self.device = layer.weight.device
        self.columns = self.layer.weight.data.shape[1]
        self.register_buffer(
            "XTX", torch.zeros((self.columns, self.columns), dtype=accumultor_dtype, device=self.device)
        )
        self.quantized_weight: Optional[QuantizedWeight] = None
        self.nsamples = 0

    @torch.no_grad()
    def add_batch(self, inp: torch.Tensor):
        """Accumulate a minibatch of layer inputs and update the X.T @ X (aka half hessian)"""
        assert self.XTX is not None, "Already ran quantization; cannot add more data batches"
        if len(inp.shape) == 3:
            inp = inp.reshape((-1, inp.shape[-1]))
        tmp = inp.shape[0]
        inp = inp.t()

        self.XTX *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(1 / self.nsamples) * inp.to(self.XTX.dtype)
        self.XTX += inp.matmul(inp.t())

    @torch.enable_grad()
    def quantize(self, *, args: Namespace, verbose: bool = True) -> QuantizedWeight:
        """create a QuantizedLinear with specified args based on the collected hessian (XTX) data"""
        assert isinstance(args.devices, (list, tuple)) and len(args.devices) >= 1, f"Found devices = {args.devices}"
        assert args.devices[0] == self.device, (args.devices[0], self.XTX.device)
        self.quantized_weight = QuantizedWeight(
            XTX=self.XTX.to(device=self.device, dtype=torch.float32),
            reference_weight=self.layer.weight.detach().to(device=self.device, dtype=torch.float32),
            out_group_size=args.out_group_size,
            in_group_size=args.in_group_size,
            num_codebooks=args.num_codebooks,
            nbits_per_codebook=args.nbits_per_codebook,
            codebook_value_nbits=args.codebook_value_nbits,
            codebook_value_num_groups=args.codebook_value_num_groups,
            scale_nbits=args.scale_nbits,
            rrr_rank=args.rrr_rank,
            max_iter=args.init_max_iter,
            max_points_per_centroid=args.max_points_per_centroid,
            devices=args.devices,
            verbose=True,
        )

        differentiable_parameters = nn.ParameterDict(
            {name: param for name, param in self.quantized_weight.named_parameters() if param.requires_grad}
        )
        opt = torch.optim.Adam(differentiable_parameters.values(), lr=args.lr, betas=(0.0, 0.95), amsgrad=True)

        replicas = None
        if len(args.devices) > 1:
            replicas = torch.nn.parallel.replicate(self, args.devices)
            replicas[0] = self

        previous_best_loss = float("inf")  # for early stopping
        for epoch in range(args.max_epochs):
            # train codebooks and scales
            for step in range(args.steps_per_epoch):
                if len(args.devices) == 1:
                    loss = self._compute_mse()
                else:
                    loss = self._compute_mse_parallel(args.devices, replicas, differentiable_parameters)

                if not torch.isfinite(loss).item():
                    raise ValueError(f"Quantization loss is {loss}")
                if step == 0 and args.relative_mse_tolerance is not None:
                    if loss.item() / previous_best_loss > (1.0 - args.relative_mse_tolerance):
                        return self.quantized_weight  # early stopping; no updates after last epoch's beam search
                    previous_best_loss = min(previous_best_loss, loss.item())

                opt.zero_grad()
                loss.backward()
                opt.step()
                if verbose and (epoch * args.steps_per_epoch + step) % args.print_frequency == 0:
                    print(f"epoch={epoch}\tstep={step}\tloss={loss.item():.10f}\t")

            # search for better codes (cluster indices)
            seed = random.getrandbits(256)
            self.beam_search_update_codes_(
                args.devices,
                replicas,
                differentiable_parameters,
                seed=seed,
                beam_size=args.beam_size,
                sparsity_regularizer=args.sparsity_regularizer,
                verbose=True,
            )
        return self.quantized_weight

    def _compute_mse(self, selection: Union[slice, ellipsis] = ...) -> torch.Tensor:
        """
        Compute the activation MSE error = ||X @ quantized_weight - X @ reference_weight||^2
        Use the square-of-difference formula to avoid materializing per-batch predictions
        :param selection: By default, compute MSE normally. If selection is specified, this method will instead
            compute MSE over a portion of output channels that align with the selected out_groups (for parallelism)
            The indices / slices must correspond to output channels (if out_group_size==1) or groups (if > 1).
            Formally, the indices must be in range [ 0 , self.out_features // self.out_group_size )
        """
        assert self.quantized_weight is not None, "must be called inside / after AQUtil.quantize"
        quantized_weight = self.quantized_weight(selection)

        if isinstance(selection, ellipsis):
            reference_weight = self.layer.weight.detach().to(quantized_weight.dtype)
        else:
            assert isinstance(selection, slice)
            out_channel_selection = slice(
                selection.start * self.quantized_weight.out_group_size,
                selection.stop * self.quantized_weight.out_group_size,
            )

            reference_weight = self.layer.weight.detach()[out_channel_selection].to(quantized_weight.dtype)
        delta_weight = (quantized_weight - reference_weight).to(self.XTX.dtype)
        return (delta_weight @ self.XTX).flatten() @ delta_weight.flatten() / self.quantized_weight.out_features

    def _substitute_and_compute_mse(self, overrides: nn.ParameterDict, selection: slice) -> torch.Tensor:
        """Utility for parallelism: replace the specified parameters of self.quantized_weight, then compute MSE"""
        for param_name, param_value in overrides.items():
            replace_parameter_(self.quantized_weight, param_name, param_value)
        return self._compute_mse(selection)

    def _compute_mse_parallel(
        self, devices: Sequence[torch.device], replicas: Sequence[AQEngine], parameters_to_replicate: nn.ParameterDict
    ) -> torch.Tensor:
        """Compute MSE in parallel over output channels"""
        replicated_parameters = torch.nn.parallel.replicate(parameters_to_replicate, devices, detach=False)
        num_output_groups = self.quantized_weight.out_features // self.quantized_weight.out_group_size
        shard_size = (num_output_groups - 1) // len(devices) + 1
        active_slices_by_replica = [
            slice(i * shard_size, min((i + 1) * shard_size, num_output_groups)) for i in range(len(devices))
        ]
        funcs_by_replica = [replica._substitute_and_compute_mse for replica in replicas]
        inputs_by_replica = [(dict(), active_slices_by_replica[0])]  # no overrides needed for 0-th replica
        for i in range(1, len(devices)):
            inputs_by_replica.append((replicated_parameters[i], active_slices_by_replica[i]))
        mse_components = torch.nn.parallel.parallel_apply(funcs_by_replica, inputs_by_replica, devices=devices)
        return Gather.apply(devices[0], 0, *(mse.view(1) for mse in mse_components)).sum()

    def _substitute_and_beam_search(self, overrides: nn.ParameterDict, selection: slice, **kwargs) -> torch.Tensor:
        """Utility for parallelism: replace the specified parameters of self.quantized_weight, then run beam search"""
        dtype = self.quantized_weight.codebooks.dtype
        for param_name, param_value in overrides.items():
            replace_parameter_(self.quantized_weight, param_name, param_value)
        out_channel_selection = slice(
            selection.start * self.quantized_weight.out_group_size,
            selection.stop * self.quantized_weight.out_group_size,
        )
        reference_weight = self.layer.weight.detach()[out_channel_selection].to(dtype)
        return self.quantized_weight.beam_search_update_codes_(
            self.XTX.to(dtype), reference_weight, selection=selection, **kwargs
        ).clone()

    @torch.no_grad()
    def beam_search_update_codes_(
        self,
        devices: Sequence[torch.device],
        replicas: Sequence[AQEngine],
        parameters_to_replicate: nn.ParameterDict,
        seed: Optional[int] = None,
        **kwargs,
    ):
        """Update self.quantized_weight.codes in-place via beam search"""
        if len(devices) == 1:  # single device
            assert replicas is None
            dtype = self.quantized_weight.codebooks.dtype
            self.quantized_weight.beam_search_update_codes_(
                self.XTX.to(dtype), self.layer.weight.detach().to(dtype), dim_rng=random.Random(seed), **kwargs
            )
        else:
            assert replicas[0] is self
            replicated_parameters = torch.nn.parallel.replicate(parameters_to_replicate, devices)
            num_output_groups = self.quantized_weight.out_features // self.quantized_weight.out_group_size
            shard_size = (num_output_groups - 1) // len(devices) + 1
            active_slices_by_replica = [
                slice(i * shard_size, min((i + 1) * shard_size, num_output_groups)) for i in range(len(devices))
            ]

            funcs_by_replica = [replica._substitute_and_beam_search for replica in replicas]
            inputs_by_replica = [(dict(), active_slices_by_replica[0])]
            for i in range(1, len(devices)):
                inputs_by_replica.append((replicated_parameters[i], active_slices_by_replica[i]))
            kwargs_by_replica = [dict(kwargs, dim_rng=random.Random(seed)) for _ in range(len(devices))]
            new_code_parts_by_replica = torch.nn.parallel.parallel_apply(
                funcs_by_replica, inputs_by_replica, kwargs_by_replica, devices=devices
            )
            # gather all code parts and assign them to each replica
            for device, replica in zip(devices, replicas):
                replica.quantized_weight.codes[...] = Gather.apply(device, 0, *new_code_parts_by_replica)


def replace_parameter_(module: nn.Module, name: str, new_value: torch.Tensor):
    """A hacky way to substitute an already registered parameter with a non-parameter tensor. Breaks future use."""
    if name in module._parameters:
        module._parameters[name] = new_value
    else:
        setattr(module, name, new_value)
```
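To see what the `add_batch` method above accumulates, the following standalone sketch (random data, not part of the repository) checks that the streaming update keeps `XTX` equal to `X.T @ X` averaged over all samples seen so far:

```python
# Standalone illustration of the running X^T X average maintained by AQEngine.add_batch.
import torch

columns, nsamples = 16, 0
xtx = torch.zeros(columns, columns, dtype=torch.float64)

batches = [torch.randn(4, 8, columns) for _ in range(3)]  # (batch, seq_len, features)
for inp in batches:
    inp = inp.reshape(-1, inp.shape[-1]).t()   # features x tokens, as in add_batch
    tmp = inp.shape[1]                          # number of new tokens
    xtx *= nsamples / (nsamples + tmp)          # rescale the old average
    nsamples += tmp
    inp = (1 / nsamples) ** 0.5 * inp.to(xtx.dtype)
    xtx += inp @ inp.t()

# Reference: concatenate all tokens and compute X^T X / nsamples directly.
x_all = torch.cat([b.reshape(-1, columns) for b in batches]).to(torch.float64)
assert torch.allclose(xtx, x_all.t() @ x_all / nsamples)
print("running update matches the full-batch average")
```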
47.9 MB (binary file not shown)
83.3 MB (binary file not shown)

data/refined_web_n=128.pth
5.29 MB (binary file not shown)
