15
15
16
16
17
17
class GPTQ :
18
-
19
18
def __init__ (self , layer ):
20
19
self .layer = layer
21
20
self .dev = self .layer .weight .device
@@ -88,6 +87,10 @@ def fasterquant(
88
87
H = torch .cholesky_inverse (H )
89
88
H = torch .linalg .cholesky (H , upper = True )
90
89
Hinv = H
90
+
91
+ scale = []
92
+ zero = []
93
+ now_idx = 1
91
94
92
95
for i1 in range (0 , self .columns , blocksize ):
93
96
i2 = min (i1 + blocksize , self .columns )
@@ -106,6 +109,11 @@ def fasterquant(
106
109
if groupsize != - 1 :
107
110
if (i1 + i ) % groupsize == 0 :
108
111
self .quantizer .find_params (W [:, (i1 + i ):(i1 + i + groupsize )], weight = True )
112
+
113
+ if ((i1 + i ) // groupsize ) - now_idx == - 1 :
114
+ scale .append (self .quantizer .scale )
115
+ zero .append (self .quantizer .zero )
116
+ now_idx += 1
109
117
110
118
q = quantize (
111
119
w .unsqueeze (1 ), self .quantizer .scale , self .quantizer .zero , self .quantizer .maxq
@@ -137,7 +145,14 @@ def fasterquant(
137
145
self .layer .weight .data = Q .reshape (self .layer .weight .shape ).to (self .layer .weight .data .dtype )
138
146
if DEBUG :
139
147
print (torch .sum ((self .layer (self .inp1 ) - self .out1 ) ** 2 ))
140
-
148
+
149
+ if scale == []:
150
+ scale .append (self .quantizer .scale )
151
+ zero .append (self .quantizer .zero )
152
+ scale = torch .cat (scale ,dim = 1 )
153
+ zero = torch .cat (zero ,dim = 1 )
154
+ return scale ,zero
155
+
141
156
def free (self ):
142
157
if DEBUG :
143
158
self .inp1 = None
0 commit comments