@@ -141,8 +141,18 @@ def forward(self, input: Tensor) -> Tensor:


 class Params4bit(torch.nn.Parameter):
-
-    def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit":
+    def __new__(
+            cls,
+            data: Optional[torch.Tensor] = None,
+            requires_grad=True,
+            quant_state: QuantState = None,
+            blocksize: int = 64,
+            compress_statistics: bool = True,
+            quant_type: str = 'fp4',
+            quant_storage: torch.dtype = torch.uint8,
+            module: Optional["Linear4bit"] = None,
+            bnb_quantized: bool = False
+    ) -> "Params4bit":
         if data is None:
             data = torch.empty(0)

@@ -151,7 +161,10 @@ def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_
         self.compress_statistics = compress_statistics
         self.quant_type = quant_type
         self.quant_state = quant_state
+        self.quant_storage = quant_storage
+        self.bnb_quantized = bnb_quantized
         self.data = data
+        self.module = module
         return self

     @classmethod
@@ -162,16 +175,23 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any],
         self.blocksize = self.quant_state.blocksize
         self.compress_statistics = self.quant_state.nested
         self.quant_type = self.quant_state.quant_type
+        self.bnb_quantized = True
         return self

-    def cuda(self, device):
-        w = self.data.contiguous().half().cuda(device)
-        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
+    def _quantize(self, device):
+        w = self.data.contiguous().cuda(device)
+        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics,
+                                                           quant_type=self.quant_type, quant_storage=self.quant_storage)
         self.data = w_4bit
         self.quant_state = quant_state
-
+        if self.module is not None:
+            self.module.quant_state = quant_state
+        self.bnb_quantized = True
         return self

+    def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+        return self.to(device='cuda' if device is None else device, non_blocking=non_blocking)
+
     @overload
     def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
         ...
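
Note (not part of the diff): with this change, cuda() no longer quantizes eagerly; it delegates to to(), which calls _quantize() on the first move to a CUDA device and mirrors the resulting quant_state onto the owning module. A rough usage sketch of that behavior, assuming a CUDA build of bitsandbytes and the layer sizes chosen only for illustration:

    import torch
    import bitsandbytes as bnb

    linear = bnb.nn.Linear4bit(64, 64, quant_type='nf4')    # weight is a Params4bit, not yet quantized
    assert not linear.weight.bnb_quantized

    linear.cuda()                                            # Params4bit.cuda() -> to() -> _quantize()
    assert linear.weight.bnb_quantized
    assert linear.quant_state is linear.weight.quant_state   # _quantize() stored the state on the module too
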
@@ -187,8 +207,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

-        if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
-            return self.cuda(device)
+        if (device is not None and device.type == "cuda" and not self.bnb_quantized):
+            return self._quantize(device)
         else:
             if self.quant_state is not None:
                 self.quant_state.to(device)
@@ -203,12 +223,14 @@ def to(self, *args, **kwargs):

 class Linear4bit(nn.Linear):

-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
         super().__init__(input_features, output_features, bias, device)
-        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
+        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
         # self.persistent_buffers = [] # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
         self.compute_type_is_set = False
+        self.quant_state = None
+        self.quant_storage = quant_storage

     def set_compute_type(self, x):
         if x.dtype in [torch.float32, torch.bfloat16]:
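
Note (not part of the diff): quant_storage controls the dtype of the tensor that holds the packed 4-bit weights, with torch.uint8 remaining the default. A minimal sketch, assuming a CUDA build of bitsandbytes and illustrative layer sizes, of using a floating storage dtype so the packed weight matches the dtype of the rest of the model (which wrappers such as FSDP expect):

    import torch
    import bitsandbytes as bnb

    # default: packed 4-bit weights stored as uint8
    layer_u8 = bnb.nn.Linear4bit(128, 256, quant_type='nf4').cuda()

    # packed 4-bit weights stored as bfloat16, so every parameter has a floating dtype
    layer_bf16 = bnb.nn.Linear4bit(128, 256, quant_type='nf4', quant_storage=torch.bfloat16).cuda()

    print(layer_u8.weight.dtype, layer_bf16.weight.dtype)   # torch.uint8 torch.bfloat16
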
@@ -243,7 +265,15 @@ def forward(self, x: torch.Tensor):
             self.bias.data = self.bias.data.to(x.dtype)

         if getattr(self.weight, 'quant_state', None) is None:
-            print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
+            if getattr(self, 'quant_state', None) is not None:
+                # the quant state got lost when the parameter got converted. This happens for example for fsdp
+                # since we registered the module, we can recover the state here
+                assert self.weight.shape[1] == 1
+                if not isinstance(self.weight, Params4bit):
+                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
+                self.weight.quant_state = self.quant_state
+            else:
+                print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
         if not self.compute_type_is_set:
             self.set_compute_type(x)
             self.compute_type_is_set = True
@@ -261,8 +291,8 @@ def forward(self, x: torch.Tensor):


 class LinearFP4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device)


 class LinearNF4(Linear4bit):
@@ -276,8 +306,8 @@ class LinearNF4(Linear4bit):
     Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
     the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
     '''
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device)


 class Int8Params(torch.nn.Parameter):
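
Note (not part of the diff): the forward() change above lets a layer recover its quant_state when a wrapper such as FSDP replaces the Params4bit weight with a plain parameter, because _quantize() also stored the state on the module. A rough sketch of that recovery path, simulated here without FSDP and assuming a CUDA build of bitsandbytes:

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear4bit(64, 64, quant_type='nf4', compute_dtype=torch.bfloat16,
                              quant_storage=torch.bfloat16).cuda()

    # simulate what parameter flattening does: the Params4bit wrapper (and its quant_state)
    # is lost, only the packed storage tensor of shape (n, 1) survives on the module
    layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

    x = torch.randn(2, 64, dtype=torch.bfloat16, device='cuda')
    y = layer(x)   # forward() rewraps the weight as Params4bit and reuses layer.quant_state
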