
Commit a270974

Update quant.py
1 parent 7a06733 commit a270974

1 file changed: quant.py (+63 -15)

@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-
+import math
 
 def quantize(x, scale, zero, maxq):
     q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
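
For reference, quantize() maps x onto the integer grid [0, maxq]. Below is a minimal round-trip sketch, assuming the usual companion dequantize step scale * (q - zero), which this hunk does not show; the 4-bit setup and tensors are illustrative:

import torch

def quantize(x, scale, zero, maxq):
    return torch.clamp(torch.round(x / scale) + zero, 0, maxq)

x = torch.randn(8)
maxq = 2 ** 4 - 1                      # 4-bit grid: integers 0..15
scale = (x.max() - x.min()) / maxq     # one asymmetric scale for the whole tensor
zero = torch.round(-x.min() / scale)   # integer zero-point
q = quantize(x, scale, zero, maxq)
x_hat = scale * (q - zero)             # assumed dequantize; error is about scale / 2 per element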
@@ -126,25 +126,37 @@ def ready(self):
 
 # Assumes layer is perfectly divisible into 256 * 256 blocks
 class QuantLinear(nn.Module):
-    def __init__(self, bits, infeatures, outfeatures):
+    def __init__(self, bits, groupsize, infeatures, outfeatures):
         super().__init__()
         if bits not in [2,3,4,8]:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+        self.infeatures = infeatures
+        self.outfeatures = outfeatures
         self.bits = bits
-        self.register_buffer('zeros', torch.zeros((outfeatures, 1)))
-        self.register_buffer('scales', torch.zeros((outfeatures, 1)))
+        if groupsize != -1 and groupsize < 32 and groupsize != int(math.pow(2,int(math.log2(groupsize)))):
+            raise NotImplementedError("groupsize supports powers of 2 greater than 32. (e.g. : 32,64,128,etc)")
+        groupsize = groupsize if groupsize != -1 else infeatures
+        self.groupsize = groupsize
+        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures/groupsize),outfeatures // 256 * (bits * 8)), dtype=torch.int))
+        self.register_buffer('scales', torch.zeros((math.ceil(infeatures/groupsize),outfeatures)))
         self.register_buffer('bias', torch.zeros(outfeatures))
         self.register_buffer(
             'qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
         )
 
     def pack(self, linear, scales, zeros):
-        self.zeros = zeros * scales
+        scales = scales.t().contiguous()
+        zeros = zeros.t().contiguous()
+        scale_zeros = zeros * scales
         self.scales = scales.clone()
         if linear.bias is not None:
-            self.bias = linear.bias.clone()
-
-        intweight = torch.round((linear.weight.data + self.zeros) / self.scales).to(torch.int)
+            self.bias = linear.bias.clone()
+
+        intweight = []
+        for idx in range(self.infeatures):
+            g_idx = idx // self.groupsize
+            intweight.append(torch.round((linear.weight.data[:,idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,None])
+        intweight = torch.cat(intweight,dim=1)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
         qweight = np.zeros(
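
The new buffers are group-wise: pack() assigns input column idx to group idx // groupsize, so each group of input columns gets its own scale and zero-point per output column. A shape sanity check under one assumed configuration (4 bits, groupsize 128, a 4096x4096 layer; the numbers are illustrative, not from the commit):

import math

bits, groupsize, infeatures, outfeatures = 4, 128, 4096, 4096
groups = math.ceil(infeatures / groupsize)                     # 32 groups along the input dim

scales_shape  = (groups, outfeatures)                          # (32, 4096): one scale per (group, column)
qzeros_shape  = (groups, outfeatures // 256 * (bits * 8))      # (32, 512): 32 // bits zeros per int32
qweight_shape = (infeatures // 256 * (bits * 8), outfeatures)  # (512, 4096): 32 // bits weights per int32
assert qzeros_shape[1] * (32 // bits) == outfeatures           # 512 * 8 == 4096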
@@ -182,6 +194,42 @@ def pack(self, linear, scales, zeros):
 
         qweight = qweight.astype(np.int32)
         self.qweight = torch.from_numpy(qweight)
+
+        zeros -= 1;
+        zeros = zeros.numpy().astype(np.uint32)
+        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32)
+        i = 0
+        col = 0
+        while col < qzeros.shape[1]:
+            if self.bits in [2,4,8]:
+                for j in range(i, i + (32//self.bits)):
+                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+                i += 32//self.bits
+                col += 1
+            elif self.bits == 3:
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 30
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 31
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
+                i += 10
+                col += 1
+            else:
+                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+
+        qzeros = qzeros.astype(np.int32)
+        self.qzeros = torch.from_numpy(qzeros)
 
     def forward(self, x):
         outshape = list(x.shape)
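
The 3-bit branch packs 32 zero-points into three int32 words (96 bits); values 10 and 21 straddle word boundaries, which is why the loop handles them separately. As a check of that layout, here is a standalone pack/unpack sketch that mirrors the shifts above for a single block (a verification aid, not code from the commit):

import numpy as np

def pack3(vals):
    """Pack 32 3-bit values (each < 8) into 3 uint32 words, per the layout above."""
    w = np.zeros(3, dtype=np.uint32)
    for k in range(10):
        w[0] |= vals[k] << (3 * k)           # values 0..9  -> bits 0..29 of word 0
    w[0] |= (vals[10] & 0x3) << 30           # value 10: low 2 bits close word 0 ...
    w[1] |= (vals[10] >> 2) & 0x1            # ... high bit opens word 1
    for k in range(10):
        w[1] |= vals[11 + k] << (3 * k + 1)  # values 11..20 -> bits 1..30 of word 1
    w[1] |= (vals[21] & 0x1) << 31           # value 21: low bit closes word 1 ...
    w[2] |= (vals[21] >> 1) & 0x3            # ... high 2 bits open word 2
    for k in range(10):
        w[2] |= vals[22 + k] << (3 * k + 2)  # values 22..31 -> bits 2..31 of word 2
    return w

def unpack3(w):
    """Invert pack3: recover the 32 3-bit values from 3 uint32 words."""
    vals = np.empty(32, dtype=np.uint32)
    for k in range(10):
        vals[k] = (w[0] >> (3 * k)) & 0x7
    vals[10] = (w[0] >> 30) | ((w[1] & 0x1) << 2)
    for k in range(10):
        vals[11 + k] = (w[1] >> (3 * k + 1)) & 0x7
    vals[21] = (w[1] >> 31) | ((w[2] & 0x3) << 1)
    for k in range(10):
        vals[22 + k] = (w[2] >> (3 * k + 2)) & 0x7
    return vals

vals = np.random.randint(0, 8, size=32).astype(np.uint32)
assert (unpack3(pack3(vals)) == vals).all()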
@@ -191,27 +239,27 @@ def forward(self, x):
         dtype = x.dtype
         x = x.float()
         if self.bits == 2:
-            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.zeros)
+            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
         elif self.bits == 3:
-            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.zeros)
+            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
         elif self.bits == 4:
-            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.zeros)
+            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
         elif self.bits == 8:
-            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.zeros)
+            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
         else:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
         y = y.to(dtype)
         return y.reshape(outshape)
 
-def make_quant(module, names, bits, name=''):
+def make_quant(module, names, bits, groupsize, name=''):
     if isinstance(module, QuantLinear):
         return
     for attr in dir(module):
         tmp = getattr(module, attr)
         name1 = name + '.' + attr if name != '' else attr
         if name1 in names:
             setattr(
-                module, attr, QuantLinear(bits, tmp.in_features, tmp.out_features)
+                module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features)
             )
     for name1, child in module.named_children():
-        make_quant(child, names, bits, name + '.' + name1 if name != '' else name1)
+        make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
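
Since make_quant() now takes groupsize, call sites need the extra argument. A hedged usage sketch (the Block module and layer sizes are hypothetical, and running forward() additionally requires the quant_cuda extension to be built):

import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4096, 4096)

model = Block()
make_quant(model, {'proj'}, 4, 128)   # bits=4, groupsize=128
assert isinstance(model.proj, QuantLinear)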
