
Commit dcfb6f8

Initial FSDP Support for QLoRA Finetuning (#970)
This PR adds initial FSDP support for training QLoRA models. It enables basic FSDP and CPU offload support; low-memory training via the FSDP `sync_module_states` option is not yet supported. This PR builds on #840 commit 8278fca and BNB FSDP by @TimDettmers and @Titus-von-Koeller. An example of using this PR to finetune QLoRA models with FSDP can be found in the demo repo: AnswerDotAi/fsdp_qlora.

* Minimal changes for fp32 4-bit storage from BNB commit 8278fca
* Params4bit with selectable storage dtype
* Possible fix for double quantizing linear weight & quant storage dtype
* Minor fixes in Params4bit for peft tests
* Remove redundant
* Add float16
* Update test
* Remove float16 quant cast as there are fp32, bf16, & fp16 quant kernels

---------

Co-authored-by: Kerem Turgutlu <[email protected]>
1 parent 64a28d0 commit dcfb6f8
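The core user-facing change is the new `quant_storage` argument on `Linear4bit` (and `Params4bit`), which selects the dtype of the packed 4-bit weight buffer. A minimal usage sketch, assuming a CUDA device and a bitsandbytes build containing this commit; the layer sizes are illustrative:

```python
import torch
import bitsandbytes as bnb

# Sketch only: 4-bit NF4 layer whose packed weights are *stored* as bfloat16
# (quant_storage), so FSDP can flatten them together with other bf16 params.
layer = bnb.nn.Linear4bit(
    4096, 4096,
    bias=False,
    compute_dtype=torch.bfloat16,
    quant_type='nf4',
    quant_storage=torch.bfloat16,   # new in this commit; default stays torch.uint8
)
layer = layer.to('cuda')            # quantization happens on the move to CUDA

x = torch.randn(2, 4096, dtype=torch.bfloat16, device='cuda')
y = layer(x)
print(layer.weight.dtype, layer.weight.shape)  # torch.bfloat16, packed shape (N, 1)
```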

File tree: 4 files changed, +99 -33 lines


bitsandbytes/functional.py

Lines changed: 13 additions & 12 deletions
```diff
@@ -607,7 +607,7 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState
 
         qs_dict: based on state_dict, with only relevant keys, striped of prefixes.
 
-        item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
+        item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
         """
 
         # unpacking tensor with non-tensor components
@@ -802,7 +802,7 @@ def dequantize_blockwise(
 
     if quant_state is None:
         quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32)
-
+
     absmax = quant_state.absmax
     if quant_state.nested:
         absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
@@ -884,13 +884,13 @@ def get_4bit_type(typename, device=None, blocksize=64):
     return data.to(device)
 
 
-def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
-    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
+def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
+    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage)
 
-def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
-    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4')
+def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
+    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage)
 
-def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor:
+def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4', quant_storage=torch.uint8) -> Tensor:
     """
     Quantize tensor A in blocks of 4-bit values.
 
@@ -903,7 +903,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
     absmax : torch.Tensor
         The absmax values.
     out : torch.Tensor
-        The output tensor (8-bit).
+        The output tensor.
     blocksize : int
         The blocksize used in quantization.
     quant_type : str
@@ -912,7 +912,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
     Returns
     -------
     torch.Tensor:
-        The 8-bit tensor with packed 4-bit values.
+        Tensor with packed 4-bit values.
     tuple(torch.Tensor, torch.Size, torch.dtype, int):
         The quantization state to undo the quantization.
     """
@@ -931,7 +931,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
 
 
     if out is None:
-        out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)
+        mod = dtype2bytes[quant_storage] * 2
+        out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device)
 
     assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
 
@@ -985,7 +986,7 @@ def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor =
     Parameters
     ----------
     A : torch.Tensor
-        The input 8-bit tensor (packed 4-bit values).
+        The input tensor (packed 4-bit values).
     quant_state : QuantState
         object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
     absmax : torch.Tensor
@@ -1626,7 +1627,7 @@ def gemv_4bit(
     ldb = ct.c_int32(ldb)
     ldc = ct.c_int32(ldc)
 
-    if B.dtype == torch.uint8:
+    if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
        if A.dtype == torch.float16:
            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
        elif A.dtype == torch.bfloat16:
```
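The `mod = dtype2bytes[quant_storage] * 2` change above means the packed buffer holds the same bytes regardless of storage dtype, just viewed in larger elements. A rough sketch of that invariant, assuming a CUDA device; the tensor size is arbitrary:

```python
import torch
import bitsandbytes.functional as F

# Sketch only: the same fp16 weight packed with two storage dtypes.
W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')

q_u8, s_u8 = F.quantize_4bit(W, quant_type='nf4', quant_storage=torch.uint8)
q_bf16, s_bf16 = F.quantize_4bit(W, quant_type='nf4', quant_storage=torch.bfloat16)

print(q_u8.shape, q_u8.dtype)      # torch.Size([8388608, 1]) torch.uint8
print(q_bf16.shape, q_bf16.dtype)  # torch.Size([4194304, 1]) torch.bfloat16

# Same number of packed bytes either way; only the element view differs.
assert q_u8.numel() * q_u8.element_size() == q_bf16.numel() * q_bf16.element_size()

# Dequantization is driven by the quant state, so the storage dtype does not matter.
W_restored = F.dequantize_4bit(q_bf16, s_bf16)
assert W_restored.shape == W.shape and W_restored.dtype == W.dtype
```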

bitsandbytes/nn/modules.py

Lines changed: 45 additions & 15 deletions
```diff
@@ -141,8 +141,18 @@ def forward(self, input: Tensor) -> Tensor:
 
 
 class Params4bit(torch.nn.Parameter):
-
-    def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit":
+    def __new__(
+            cls,
+            data: Optional[torch.Tensor] = None,
+            requires_grad=True,
+            quant_state: QuantState = None,
+            blocksize: int = 64,
+            compress_statistics: bool = True,
+            quant_type: str = 'fp4',
+            quant_storage: torch.dtype = torch.uint8,
+            module: Optional["Linear4bit"] = None,
+            bnb_quantized: bool = False
+    ) -> "Params4bit":
         if data is None:
             data = torch.empty(0)
 
@@ -151,7 +161,10 @@ def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_
         self.compress_statistics = compress_statistics
         self.quant_type = quant_type
         self.quant_state = quant_state
+        self.quant_storage = quant_storage
+        self.bnb_quantized = bnb_quantized
         self.data = data
+        self.module = module
         return self
 
     @classmethod
@@ -162,16 +175,23 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any],
         self.blocksize = self.quant_state.blocksize
         self.compress_statistics = self.quant_state.nested
         self.quant_type = self.quant_state.quant_type
+        self.bnb_quantized = True
         return self
 
-    def cuda(self, device):
-        w = self.data.contiguous().half().cuda(device)
-        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
+    def _quantize(self, device):
+        w = self.data.contiguous().cuda(device)
+        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics,
+                                                           quant_type=self.quant_type, quant_storage=self.quant_storage)
         self.data = w_4bit
         self.quant_state = quant_state
-
+        if self.module is not None:
+            self.module.quant_state = quant_state
+        self.bnb_quantized = True
         return self
 
+    def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
+        return self.to(device='cuda' if device is None else device, non_blocking=non_blocking)
+
     @overload
     def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
         ...
@@ -187,8 +207,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
-        if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
-            return self.cuda(device)
+        if (device is not None and device.type == "cuda" and not self.bnb_quantized):
+            return self._quantize(device)
         else:
             if self.quant_state is not None:
                 self.quant_state.to(device)
@@ -203,12 +223,14 @@ def to(self, *args, **kwargs):
 
 class Linear4bit(nn.Linear):
 
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
         super().__init__(input_features, output_features, bias, device)
-        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
+        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
         # self.persistent_buffers = [] # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
         self.compute_type_is_set = False
+        self.quant_state = None
+        self.quant_storage = quant_storage
 
     def set_compute_type(self, x):
         if x.dtype in [torch.float32, torch.bfloat16]:
@@ -243,7 +265,15 @@ def forward(self, x: torch.Tensor):
             self.bias.data = self.bias.data.to(x.dtype)
 
         if getattr(self.weight, 'quant_state', None) is None:
-            print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
+            if getattr(self, 'quant_state', None) is not None:
+                # the quant state got lost when the parameter got converted. This happens for example for fsdp
+                # since we registered the module, we can recover the state here
+                assert self.weight.shape[1] == 1
+                if not isinstance(self.weight, Params4bit):
+                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
+                self.weight.quant_state = self.quant_state
+            else:
+                print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
         if not self.compute_type_is_set:
             self.set_compute_type(x)
             self.compute_type_is_set = True
@@ -261,8 +291,8 @@ def forward(self, x: torch.Tensor):
 
 
 class LinearFP4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device)
 
 
 class LinearNF4(Linear4bit):
@@ -276,8 +306,8 @@ class LinearNF4(Linear4bit):
     Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
     the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
     '''
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device)
 
 
 class Int8Params(torch.nn.Parameter):
```
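The `module=self` back-reference and the new `self.quant_state = None` slot exist so that `forward` can recover the quantization state after FSDP flattens the parameter and the `Params4bit` subclass information is lost. A rough sketch of that recovery path, simulating the parameter swap by hand rather than through FSDP (assumes a CUDA device; sizes and fp32 storage are illustrative):

```python
import torch
import bitsandbytes as bnb

# Sketch only: fp32 storage, i.e. the setting targeted by this PR for FSDP.
layer = bnb.nn.Linear4bit(128, 256, bias=False, compute_dtype=torch.float16,
                          quant_type='nf4', quant_storage=torch.float32).cuda()
x = torch.randn(4, 128, dtype=torch.float16, device='cuda')
y_ref = layer(x)

# Simulate what happens after FSDP flattens/unflattens the parameter: the weight
# comes back as a plain nn.Parameter holding the packed data, and the Params4bit
# attributes (including quant_state) are gone.
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

# forward() notices the missing quant_state, re-wraps the weight as Params4bit,
# and re-attaches the state that was saved on the module during quantization.
y = layer(x)
assert torch.allclose(y, y_ref)
```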

tests/test_functional.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -2370,7 +2370,8 @@ def test_normal_map_tree():
 @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
 @pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_gemv_4bit(dtype, storage_type, double_quant, kind):
+@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=['uint8', 'fp16', 'bf16', 'fp32'])
+def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
     for dim in [128, 256, 512, 1024]:
     #for dim in [4*1024]:
     #for dim in [1*16]:
@@ -2399,7 +2400,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
             A = torch.randn(1, dim, dtype=dtype, device='cuda')
             B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
 
-            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage)
             C3 = torch.matmul(A, B.t())
             C2 = F.gemv_4bit(A, qB.t(), state=state)
             A.requires_grad = True
```
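What the new `quant_storage` parametrization exercises: `gemv_4bit` now accepts packed weights stored in float dtypes, not just `torch.uint8`. A rough standalone sketch along the lines of the test, assuming a CUDA device; the comparison is approximate because 4-bit quantization is lossy:

```python
import math
import torch
import bitsandbytes.functional as F

# Sketch only: 4-bit GEMV where the packed weights are stored as bfloat16.
dim = 512
A = torch.randn(1, dim, dtype=torch.float16, device='cuda')
B = torch.randn(dim * 3, dim, dtype=torch.float16, device='cuda') / math.sqrt(dim)

qB, state = F.quantize_4bit(B, quant_type='nf4', compress_statistics=False,
                            quant_storage=torch.bfloat16)
C_ref = torch.matmul(A, B.t())
C_4bit = F.gemv_4bit(A, qB.t(), state=state)

print((C_ref - C_4bit).abs().mean())  # small quantization error, not exact equality
```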

tests/test_linear4bit.py

Lines changed: 38 additions & 4 deletions
```diff
@@ -8,13 +8,19 @@
 
 import bitsandbytes as bnb
 
+storage = {
+    'uint8': torch.uint8,
+    'float16': torch.float16,
+    'bfloat16': torch.bfloat16,
+    'float32': torch.float32
+}
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 @pytest.mark.parametrize(
-    "quant_type, compress_statistics, bias",
-    list(product(["nf4", "fp4"], [False, True], [False, True])),
+    "quant_type, compress_statistics, bias, quant_storage",
+    list(product(["nf4", "fp4"], [False, True], [False, True], ['uint8', 'float16', 'bfloat16', 'float32'])),
 )
-def test_linear_serialization(quant_type, compress_statistics, bias):
+def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
     original_dtype = torch.float16
     compute_dtype = None
     device = "cuda"
@@ -32,7 +38,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         quant_type=quant_type,
         device="meta",
     )
-    new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
+    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
     linear_q.weight = new_weight
     if bias:
         linear_q.bias = torch.nn.Parameter(linear.bias)
@@ -65,6 +71,22 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     # MATCHING
     a, b = linear_q.weight, linear_q2.weight
 
+    # Quantizing original layer with specified quant_storage type
+    linear_qs = bnb.nn.Linear4bit(
+        linear.in_features,
+        linear.out_features,
+        bias=bias,
+        compute_dtype=compute_dtype,
+        compress_statistics=compress_statistics,
+        quant_type=quant_type,
+        quant_storage=storage[quant_storage],
+        device="meta",
+    )
+    linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage])
+    if bias:
+        linear_qs.bias = torch.nn.Parameter(linear.bias)
+    linear_qs = linear_qs.to(device)
+
     assert a.device == b.device
     assert a.dtype == b.dtype
     assert torch.equal(a, b)
@@ -96,9 +118,21 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     x = torch.rand(42, layer_shape[0], device=device)
     a = linear_q(x)
     b = linear_q2(x)
+    c = linear_qs(x)
     assert a.device == b.device
     assert a.dtype == b.dtype
+    assert a.device == c.device
+    assert a.dtype == c.dtype
     assert torch.equal(a, b)
+    assert torch.equal(a, c)
+
+    # Test moving to CPU and back to GPU
+    linear_q2.to('cpu')
+    linear_q2.to(device)
+    d = linear_qs(x)
+    assert c.dtype == d.dtype
+    assert c.device == d.device
+    assert torch.equal(c, d)
 
     # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
     with TemporaryDirectory() as tmpdir:
```
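Putting the pieces together for FSDP: with `quant_storage` matching the model's float dtype, the packed 4-bit weights can be flat-sharded alongside regular parameters. A rough sketch of wrapping a toy model, assuming the script is launched with `torchrun` and one CUDA device per rank; the model, wrapping policy, and training loop are illustrative placeholders, and a complete recipe lives in the AnswerDotAi/fsdp_qlora demo repo referenced in the commit message:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import bitsandbytes as bnb

# Sketch only: launch with `torchrun --nproc_per_node=<gpus> this_script.py`.
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Toy stand-in for a QLoRA base model: quant_storage matches the bf16 params,
# so the packed 4-bit weights can live in the same FSDP flat parameter.
model = torch.nn.Sequential(
    bnb.nn.Linear4bit(1024, 4096, bias=False, compute_dtype=torch.bfloat16,
                      quant_type='nf4', quant_storage=torch.bfloat16),
    torch.nn.GELU(),
    bnb.nn.Linear4bit(4096, 1024, bias=False, compute_dtype=torch.bfloat16,
                      quant_type='nf4', quant_storage=torch.bfloat16),
).cuda()  # quantize *before* wrapping; FSDP then shards the packed bf16 buffers

fsdp_model = FSDP(model)  # low-memory init via sync_module_states is not supported yet
out = fsdp_model(torch.randn(8, 1024, dtype=torch.bfloat16, device='cuda'))
```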
