from __future__ import annotations

import math
import random
from argparse import Namespace
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
from torch.nn.parallel.scatter_gather import Gather

from src.aq import QuantizedWeight
from src.utils import ellipsis


class AQEngine(nn.Module):
    """A wrapper class that runs AQ training for a single linear layer. All the important math is in aq.py"""

    def __init__(self, layer: nn.Linear, accumulator_dtype: torch.dtype = torch.float64):
        super().__init__()
        self.layer = layer
        self.device = layer.weight.device
        self.columns = self.layer.weight.data.shape[1]
        self.register_buffer(
            "XTX", torch.zeros((self.columns, self.columns), dtype=accumulator_dtype, device=self.device)
        )
        self.quantized_weight: Optional[QuantizedWeight] = None
        self.nsamples = 0

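    # XTX is kept in float64 by default (accumulator_dtype), presumably so that accumulating
    # many small calibration batches does not lose precision before quantization starts.
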
    @torch.no_grad()
    def add_batch(self, inp: torch.Tensor):
        """Accumulate a minibatch of layer inputs and update the running average of X^T @ X (a.k.a. the half-Hessian)"""
        assert self.XTX is not None, "Already ran quantization; cannot add more data batches"
        if len(inp.shape) == 3:
            inp = inp.reshape((-1, inp.shape[-1]))
        tmp = inp.shape[0]
        inp = inp.t()

        self.XTX *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(1 / self.nsamples) * inp.to(self.XTX.dtype)
        self.XTX += inp.matmul(inp.t())

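    # Note on the accumulator update above: XTX stores the running average (1 / nsamples) * sum_i x_i x_i^T.
    # Rescaling the old accumulator by nsamples_old / (nsamples_old + batch_size) and adding the new batch
    # with the 1 / sqrt(nsamples_new) factor folded into `inp` preserves this invariant without ever
    # storing the unnormalized sum.
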
    @torch.enable_grad()
    def quantize(self, *, args: Namespace, verbose: bool = True) -> QuantizedWeight:
        """Create a QuantizedWeight based on the collected Hessian (XTX) data and train it with the specified args"""
        assert isinstance(args.devices, (list, tuple)) and len(args.devices) >= 1, f"Found devices = {args.devices}"
        assert args.devices[0] == self.device, (args.devices[0], self.XTX.device)
        self.quantized_weight = QuantizedWeight(
            XTX=self.XTX.to(device=self.device, dtype=torch.float32),
            reference_weight=self.layer.weight.detach().to(device=self.device, dtype=torch.float32),
            out_group_size=args.out_group_size,
            in_group_size=args.in_group_size,
            num_codebooks=args.num_codebooks,
            nbits_per_codebook=args.nbits_per_codebook,
            codebook_value_nbits=args.codebook_value_nbits,
            codebook_value_num_groups=args.codebook_value_num_groups,
            scale_nbits=args.scale_nbits,
            rrr_rank=args.rrr_rank,
            max_iter=args.init_max_iter,
            max_points_per_centroid=args.max_points_per_centroid,
            devices=args.devices,
            verbose=True,
        )

        differentiable_parameters = nn.ParameterDict(
            {name: param for name, param in self.quantized_weight.named_parameters() if param.requires_grad}
        )
        opt = torch.optim.Adam(differentiable_parameters.values(), lr=args.lr, betas=(0.0, 0.95), amsgrad=True)

        replicas = None
        if len(args.devices) > 1:
            replicas = torch.nn.parallel.replicate(self, args.devices)
            replicas[0] = self  # the first replica is self, so it keeps sharing parameters with this engine

        previous_best_loss = float("inf")  # for early stopping
        for epoch in range(args.max_epochs):
            # train codebooks and scales
            for step in range(args.steps_per_epoch):
                if len(args.devices) == 1:
                    loss = self._compute_mse()
                else:
                    loss = self._compute_mse_parallel(args.devices, replicas, differentiable_parameters)

                if not torch.isfinite(loss).item():
                    raise ValueError(f"Quantization loss is {loss}")
                if step == 0 and args.relative_mse_tolerance is not None:
                    if loss.item() / previous_best_loss > (1.0 - args.relative_mse_tolerance):
                        return self.quantized_weight  # early stopping; no updates after the last epoch's beam search
                    previous_best_loss = min(previous_best_loss, loss.item())

                opt.zero_grad()
                loss.backward()
                opt.step()
                if verbose and (epoch * args.steps_per_epoch + step) % args.print_frequency == 0:
                    print(f"epoch={epoch}\tstep={step}\tloss={loss.item():.10f}\t")

            # search for better codes (cluster indices)
            seed = random.getrandbits(256)
            self.beam_search_update_codes_(
                args.devices,
                replicas,
                differentiable_parameters,
                seed=seed,
                beam_size=args.beam_size,
                sparsity_regularizer=args.sparsity_regularizer,
                verbose=True,
            )
        return self.quantized_weight

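    # The loop in quantize() alternates two kinds of updates: Adam steps on the continuous
    # parameters (codebooks and scales) within each epoch, followed by one discrete beam-search
    # pass over the codes at the end of the epoch. Early stopping fires when the loss measured
    # at the start of an epoch improves on the previous best by less than args.relative_mse_tolerance.
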
    def _compute_mse(self, selection: Union[slice, ellipsis] = ...) -> torch.Tensor:
        """
        Compute the activation MSE error = ||X @ quantized_weight.T - X @ reference_weight.T||^2
        Use the square-of-difference formula to avoid materializing per-batch predictions
        :param selection: By default, compute MSE normally. If selection is specified, this method will instead
        compute MSE over a portion of output channels that align with the selected out_groups (for parallelism)
        The indices / slices must correspond to output channels (if out_group_size==1) or groups (if > 1).
        Formally, the indices must be in range [ 0 , self.out_features // self.out_group_size )
        """
        assert self.quantized_weight is not None, "must be called inside / after AQEngine.quantize"
        quantized_weight = self.quantized_weight(selection)

        if isinstance(selection, ellipsis):
            reference_weight = self.layer.weight.detach().to(quantized_weight.dtype)
        else:
            assert isinstance(selection, slice)
            out_channel_selection = slice(
                selection.start * self.quantized_weight.out_group_size,
                selection.stop * self.quantized_weight.out_group_size,
            )

            reference_weight = self.layer.weight.detach()[out_channel_selection].to(quantized_weight.dtype)
        delta_weight = (quantized_weight - reference_weight).to(self.XTX.dtype)
        return (delta_weight @ self.XTX).flatten() @ delta_weight.flatten() / self.quantized_weight.out_features

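    # Why the square-of-difference form in _compute_mse works: writing
    # delta_W = quantized_weight - reference_weight, one has
    #     ||X @ delta_W.T||_F^2 = trace(delta_W @ (X.T @ X) @ delta_W.T),
    # so the calibration activations are only needed through the accumulated XTX matrix, and
    # (delta_W @ XTX).flatten() @ delta_W.flatten() evaluates exactly that trace.
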
    def _substitute_and_compute_mse(self, overrides: nn.ParameterDict, selection: slice) -> torch.Tensor:
        """Utility for parallelism: replace the specified parameters of self.quantized_weight, then compute MSE"""
        for param_name, param_value in overrides.items():
            replace_parameter_(self.quantized_weight, param_name, param_value)
        return self._compute_mse(selection)

    def _compute_mse_parallel(
        self, devices: Sequence[torch.device], replicas: Sequence[AQEngine], parameters_to_replicate: nn.ParameterDict
    ) -> torch.Tensor:
        """Compute MSE in parallel over output channels"""
        replicated_parameters = torch.nn.parallel.replicate(parameters_to_replicate, devices, detach=False)
        num_output_groups = self.quantized_weight.out_features // self.quantized_weight.out_group_size
        shard_size = (num_output_groups - 1) // len(devices) + 1
        active_slices_by_replica = [
            slice(i * shard_size, min((i + 1) * shard_size, num_output_groups)) for i in range(len(devices))
        ]
        funcs_by_replica = [replica._substitute_and_compute_mse for replica in replicas]
        inputs_by_replica = [(dict(), active_slices_by_replica[0])]  # no overrides needed for 0-th replica
        for i in range(1, len(devices)):
            inputs_by_replica.append((replicated_parameters[i], active_slices_by_replica[i]))
        mse_components = torch.nn.parallel.parallel_apply(funcs_by_replica, inputs_by_replica, devices=devices)
        return Gather.apply(devices[0], 0, *(mse.view(1) for mse in mse_components)).sum()

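    # Sharding note: each replica handles a contiguous slice of output groups, with
    # shard_size = ceil(num_output_groups / len(devices)). Since the MSE decomposes as a sum
    # over output channels, the per-replica partial losses are gathered onto devices[0] and
    # added; Gather is an autograd Function, so the summed loss stays differentiable across devices.
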
    def _substitute_and_beam_search(self, overrides: nn.ParameterDict, selection: slice, **kwargs) -> torch.Tensor:
        """Utility for parallelism: replace the specified parameters of self.quantized_weight, then run beam search"""
        dtype = self.quantized_weight.codebooks.dtype
        for param_name, param_value in overrides.items():
            replace_parameter_(self.quantized_weight, param_name, param_value)
        out_channel_selection = slice(
            selection.start * self.quantized_weight.out_group_size,
            selection.stop * self.quantized_weight.out_group_size,
        )
        reference_weight = self.layer.weight.detach()[out_channel_selection].to(dtype)
        return self.quantized_weight.beam_search_update_codes_(
            self.XTX.to(dtype), reference_weight, selection=selection, **kwargs
        ).clone()

    @torch.no_grad()
    def beam_search_update_codes_(
        self,
        devices: Sequence[torch.device],
        replicas: Sequence[AQEngine],
        parameters_to_replicate: nn.ParameterDict,
        seed: Optional[int] = None,
        **kwargs,
    ):
        """Update self.quantized_weight.codes in-place via beam search"""
        if len(devices) == 1:  # single device
            assert replicas is None
            dtype = self.quantized_weight.codebooks.dtype
            self.quantized_weight.beam_search_update_codes_(
                self.XTX.to(dtype), self.layer.weight.detach().to(dtype), dim_rng=random.Random(seed), **kwargs
            )
        else:
            assert replicas[0] is self
            replicated_parameters = torch.nn.parallel.replicate(parameters_to_replicate, devices)
            num_output_groups = self.quantized_weight.out_features // self.quantized_weight.out_group_size
            shard_size = (num_output_groups - 1) // len(devices) + 1
            active_slices_by_replica = [
                slice(i * shard_size, min((i + 1) * shard_size, num_output_groups)) for i in range(len(devices))
            ]

            funcs_by_replica = [replica._substitute_and_beam_search for replica in replicas]
            inputs_by_replica = [(dict(), active_slices_by_replica[0])]
            for i in range(1, len(devices)):
                inputs_by_replica.append((replicated_parameters[i], active_slices_by_replica[i]))
            # every replica gets the same seed so that seed-dependent choices inside beam search stay in sync
            kwargs_by_replica = [dict(kwargs, dim_rng=random.Random(seed)) for _ in range(len(devices))]
            new_code_parts_by_replica = torch.nn.parallel.parallel_apply(
                funcs_by_replica, inputs_by_replica, kwargs_by_replica, devices=devices
            )
            # gather all code parts and assign them to each replica
            for device, replica in zip(devices, replicas):
                replica.quantized_weight.codes[...] = Gather.apply(device, 0, *new_code_parts_by_replica)


def replace_parameter_(module: nn.Module, name: str, new_value: torch.Tensor):
    """A hacky way to substitute an already registered parameter with a non-parameter tensor. Breaks future use."""
    if name in module._parameters:
        module._parameters[name] = new_value
    else:
        setattr(module, name, new_value)
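
# Usage sketch (illustrative; the model, calibration loader, and `args` namespace named below are
# hypothetical placeholders): a typical calibration loop registers a forward pre-hook on the target
# nn.Linear so that AQEngine.add_batch sees the same inputs the layer receives, then calls quantize().
#
#     layer = model.model.layers[0].mlp.down_proj       # hypothetical target layer
#     engine = AQEngine(layer)
#     hook = layer.register_forward_pre_hook(lambda module, inputs: engine.add_batch(inputs[0]))
#     with torch.no_grad():
#         for batch in calibration_loader:              # hypothetical calibration data
#             model(batch)
#     hook.remove()
#     quantized_weight = engine.quantize(args=args, verbose=True)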