Commit 7a06733

Update gptq.py
1 parent 468c47c commit 7a06733

File tree

1 file changed: +17 −2 lines changed

gptq.py

Lines changed: 17 additions & 2 deletions
@@ -15,7 +15,6 @@


 class GPTQ:
-
     def __init__(self, layer):
         self.layer = layer
         self.dev = self.layer.weight.device
@@ -88,6 +87,10 @@ def fasterquant(
         H = torch.cholesky_inverse(H)
         H = torch.linalg.cholesky(H, upper=True)
         Hinv = H
+
+        scale = []
+        zero = []
+        now_idx = 1

         for i1 in range(0, self.columns, blocksize):
             i2 = min(i1 + blocksize, self.columns)
@@ -106,6 +109,11 @@ def fasterquant(
                 if groupsize != -1:
                     if (i1 + i) % groupsize == 0:
                         self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)
+
+                        if ((i1 + i) // groupsize) - now_idx == -1:
+                            scale.append(self.quantizer.scale)
+                            zero.append(self.quantizer.zero)
+                            now_idx += 1

                 q = quantize(
                     w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
@@ -137,7 +145,14 @@ def fasterquant(
         self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
         if DEBUG:
             print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
-
+
+        if scale == []:
+            scale.append(self.quantizer.scale)
+            zero.append(self.quantizer.zero)
+        scale = torch.cat(scale,dim=1)
+        zero = torch.cat(zero,dim=1)
+        return scale,zero
+
     def free(self):
         if DEBUG:
             self.inp1 = None
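
With this change, fasterquant records the quantizer's per-group scale and zero point (once per group of groupsize columns, tracked with now_idx) and returns them concatenated along dim 1; when group-wise quantization is disabled or no group was recorded, it falls back to appending the layer-wide scale and zero before returning. A minimal caller-side sketch of how the new return value might be consumed follows — the gptq instance, the quantizers dict, the layer name key, and the exact fasterquant keyword arguments are assumptions, not part of this commit:

# Hypothetical usage sketch: collect the per-group quantization parameters
# now returned by the patched fasterquant. Assumes `gptq` is a GPTQ instance
# whose Hessian statistics have already been accumulated via add_batch.
scale, zero = gptq.fasterquant(blocksize=128, percdamp=0.01, groupsize=128)

# One column per group: roughly (out_features, columns // groupsize) when
# groupsize != -1, otherwise (out_features, 1) from the fallback branch.
quantizers["layer.name"] = (scale.cpu(), zero.cpu())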
