import numpy as np
import torch
import torch.nn as nn
+import math


def quantize(x, scale, zero, maxq):
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
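The `quantize` helper implements the standard affine scheme: scale, shift by a zero-point, clamp to `[0, maxq]`. A minimal round-trip sketch follows; the `dequantize` inverse shown here is an assumption (the usual `scale * (q - zero)`), since only the clamp/round line is visible in this hunk.

```python
# Round-trip sketch for the affine scheme in quantize().
# NOTE: dequantize() is assumed (standard inverse); it does not appear in the hunk above.
import torch

def quantize(x, scale, zero, maxq):
    return torch.clamp(torch.round(x / scale) + zero, 0, maxq)

def dequantize(q, scale, zero):
    return scale * (q - zero)

bits = 4
maxq = 2 ** bits - 1                     # 15 levels above zero for 4-bit
x = torch.randn(8)
scale = (x.max() - x.min()) / maxq
zero = torch.round(-x.min() / scale)     # zero-point maps x.min() to level 0
x_hat = dequantize(quantize(x, scale, zero, maxq), scale, zero)
print((x - x_hat).abs().max())           # reconstruction error on the order of scale / 2
```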
@@ -126,25 +126,37 @@ def ready(self):


# Assumes layer is perfectly divisible into 256 * 256 blocks
class QuantLinear(nn.Module):
-    def __init__(self, bits, infeatures, outfeatures):
+    def __init__(self, bits, groupsize, infeatures, outfeatures):
        super().__init__()
        if bits not in [2, 3, 4, 8]:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+        self.infeatures = infeatures
+        self.outfeatures = outfeatures
        self.bits = bits
-        self.register_buffer('zeros', torch.zeros((outfeatures, 1)))
-        self.register_buffer('scales', torch.zeros((outfeatures, 1)))
+        # groupsize must be -1 (a single group spanning all input features) or a power of 2 >= 32
+        if groupsize != -1 and (groupsize < 32 or groupsize != int(math.pow(2, int(math.log2(groupsize))))):
+            raise NotImplementedError("groupsize supports powers of 2 greater than or equal to 32 (e.g. 32, 64, 128).")
+        groupsize = groupsize if groupsize != -1 else infeatures
+        self.groupsize = groupsize
+        # one row of scales and packed zero-points per group of input features
+        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int))
+        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
        self.register_buffer('bias', torch.zeros(outfeatures))
        self.register_buffer(
            'qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
        )

    def pack(self, linear, scales, zeros):
-        self.zeros = zeros * scales
+        scales = scales.t().contiguous()
+        zeros = zeros.t().contiguous()
+        scale_zeros = zeros * scales
        self.scales = scales.clone()
        if linear.bias is not None:
            self.bias = linear.bias.clone()

-        intweight = torch.round((linear.weight.data + self.zeros) / self.scales).to(torch.int)
+        # quantize each input column against the scale/zero-point of its group
+        intweight = []
+        for idx in range(self.infeatures):
+            g_idx = idx // self.groupsize
+            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:, None])
+        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        qweight = np.zeros(
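The buffer shapes in `__init__` encode the packing density: `infeatures // 256 * (bits * 8)` equals `infeatures * bits // 32` whenever `infeatures` is a multiple of 256 (which the class assumes), i.e. each 32-bit word holds `32 // bits` quantized values, with one row of `scales`/`qzeros` per group. A small sketch of the arithmetic, using hypothetical LLaMA-like dimensions:

```python
# Sketch of the packed-buffer shapes declared in __init__ above.
# The dimensions (4096 x 4096, groupsize 128) are hypothetical examples.
import math

def packed_shapes(bits, groupsize, infeatures, outfeatures):
    groups = math.ceil(infeatures / groupsize)
    # infeatures // 256 * (bits * 8) == infeatures * bits // 32 for multiples of 256
    return {
        'qweight': (infeatures // 256 * (bits * 8), outfeatures),
        'qzeros': (groups, outfeatures // 256 * (bits * 8)),
        'scales': (groups, outfeatures),
    }

print(packed_shapes(bits=4, groupsize=128, infeatures=4096, outfeatures=4096))
# {'qweight': (512, 4096), 'qzeros': (32, 512), 'scales': (32, 4096)}
```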
@@ -182,6 +194,42 @@ def pack(self, linear, scales, zeros):
        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)
+
+        # zero-points are packed offset by one
+        zeros -= 1
+        zeros = zeros.numpy().astype(np.uint32)
+        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32)
+        i = 0
+        col = 0
+        while col < qzeros.shape[1]:
+            if self.bits in [2, 4, 8]:
+                # 32 // bits values fit exactly into each uint32 word
+                for j in range(i, i + (32 // self.bits)):
+                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                col += 1
+            elif self.bits == 3:
+                # 32 three-bit values span three uint32 words; two values straddle word boundaries
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 30
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 31
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
+                i += 10
+                col += 1
+            else:
+                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+
+        qzeros = qzeros.astype(np.int32)
+        self.qzeros = torch.from_numpy(qzeros)

    def forward(self, x):
        outshape = list(x.shape)
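For 2, 4, and 8 bits each uint32 word holds a whole number of zero-points, but the 3-bit branch packs 32 values into exactly three words (96 bits), with values 10 and 21 straddling word boundaries. The inverse below is not part of the commit; it is a sketch written only to make that bit layout explicit:

```python
# Hedged sketch: unpacking the 3-bit layout written by pack() above.
# Not part of the commit; an illustrative inverse of the packing loop.
import numpy as np

def unpack3(words):
    # words: uint32 array of shape (rows, 3 * k); returns (rows, 32 * k)
    out = np.zeros((words.shape[0], words.shape[1] // 3 * 32), dtype=np.uint32)
    i, col = 0, 0
    while col < words.shape[1]:
        w0, w1, w2 = words[:, col], words[:, col + 1], words[:, col + 2]
        for j in range(10):                  # values 0..9: bits 0..29 of w0
            out[:, i + j] = (w0 >> (3 * j)) & 0x7
        out[:, i + 10] = ((w0 >> 30) & 0x3) | ((w1 & 0x1) << 2)   # straddles w0/w1
        for j in range(10):                  # values 11..20: bits 1..30 of w1
            out[:, i + 11 + j] = (w1 >> (3 * j + 1)) & 0x7
        out[:, i + 21] = ((w1 >> 31) & 0x1) | ((w2 & 0x3) << 1)   # straddles w1/w2
        for j in range(10):                  # values 22..31: bits 2..31 of w2
            out[:, i + 22 + j] = (w2 >> (3 * j + 2)) & 0x7
        i += 32
        col += 3
    return out
```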
@@ -191,27 +239,27 @@ def forward(self, x):
        dtype = x.dtype
        x = x.float()
        if self.bits == 2:
-            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.zeros)
+            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 3:
-            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.zeros)
+            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 4:
-            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.zeros)
+            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 8:
-            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.zeros)
+            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        else:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        y = y.to(dtype)
        return y.reshape(outshape)

-def make_quant(module, names, bits, name=''):
+def make_quant(module, names, bits, groupsize, name=''):
    if isinstance(module, QuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            setattr(
-                module, attr, QuantLinear(bits, tmp.in_features, tmp.out_features)
+                module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features)
            )
    for name1, child in module.named_children():
-        make_quant(child, names, bits, name + '.' + name1 if name != '' else name1)
+        make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
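`make_quant` walks a module tree and swaps every named `nn.Linear` for a `QuantLinear` with empty packed buffers, to be filled later by `pack` or a checkpoint load. A hedged usage sketch; the module name `proj` and the dimensions are hypothetical:

```python
# Hypothetical usage of make_quant(); 'proj' and the sizes are made up for illustration.
import torch.nn as nn

model = nn.Sequential()
model.add_module('proj', nn.Linear(4096, 4096, bias=True))

make_quant(model, names=['proj'], bits=4, groupsize=128)
print(model.proj)  # now a QuantLinear holding int32 qweight/qzeros buffers
```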