[bitsandbytes] Simplify bnb int8 dequant #10401


Merged
merged 13 commits on Feb 4, 2025
30 changes: 18 additions & 12 deletions src/diffusers/quantizers/bitsandbytes/utils.py
@@ -154,7 +154,7 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name


# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
- def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
+ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):
"""
Helper function to dequantize 4bit or 8bit bnb weights.

@@ -177,13 +177,16 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
if state.SCB is None:
state.SCB = weight.SCB

- im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
- im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
- im, Sim = bnb.functional.transform(im, "col32")
- if state.CxB is None:
-     state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
- out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
- return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+ if hasattr(bnb.functional, "int8_vectorwise_dequant"):
+     # Use bitsandbytes API if available (requires v0.45.0+)
+     dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
+ else:
+     # Multiply by (scale/127) to dequantize.
+     dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
+
+ if dtype:
+     dequantized = dequantized.to(dtype)
+ return dequantized


def _create_accelerate_new_hook(old_hook):
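
For context on the fallback branch above: LLM.int8() stores each weight row as int8 values together with a per-row absmax scale (state.SCB), so dequantization reduces to int8_weight * SCB / 127; the constant 7.874015718698502e-3 is approximately 1/127. A minimal, self-contained sketch of that round trip (synthetic tensors, no bitsandbytes required):

import torch

# Hypothetical weight matrix; row-wise absmax quantization in the style of LLM.int8().
w = torch.randn(4, 8)
scb = w.abs().max(dim=1).values  # per-row absmax scale, analogous to state.SCB
w_int8 = torch.round(w * 127.0 / scb.view(-1, 1)).to(torch.int8)

# Dequantize by multiplying with (scale / 127), i.e. the 7.874015718698502e-3 constant.
w_deq = w_int8.to(torch.float32) * scb.view(-1, 1) * (1.0 / 127.0)

# Round-trip error is bounded by half a quantization step per row.
print(torch.allclose(w, w_deq, atol=(scb.max() / 127.0).item()))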
@@ -205,6 +208,7 @@ def _create_accelerate_new_hook(old_hook):

def _dequantize_and_replace(
model,
+ dtype,
Member Author

This is a private method, so modifying its signature should be fine.

modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
@@ -244,7 +248,7 @@ def _dequantize_and_replace(
else:
state = None

- new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+ new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype))

if bias is not None:
new_module.bias = bias
@@ -263,9 +267,10 @@ def _dequantize_and_replace(
if len(list(module.children())) > 0:
_, has_been_replaced = _dequantize_and_replace(
module,
- modules_to_not_convert,
- current_key_name,
- quantization_config,
+ dtype=dtype,
+ modules_to_not_convert=modules_to_not_convert,
+ current_key_name=current_key_name,
+ quantization_config=quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
@@ -280,6 +285,7 @@ def dequantize_and_replace(
):
model, has_been_replaced = _dequantize_and_replace(
model,
+ dtype=model.dtype,
modules_to_not_convert=modules_to_not_convert,
quantization_config=quantization_config,
)
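
Note that dequantize_and_replace now threads model.dtype down to dequantize_bnb_weight, so the restored nn.Linear weights come back in the model's compute dtype rather than the float32 produced by the scale multiplication (SCB is typically kept in float32). A small illustrative sketch of that dtype behaviour (the tensors and the model_dtype name are assumptions for illustration, not part of this PR):

import torch

# Stand-ins for an int8 weight and its per-row scales.
w_int8 = torch.randint(-127, 128, (4, 8), dtype=torch.int8)
scb = torch.rand(4, dtype=torch.float32)

# The fallback multiply promotes the int8 data to float32.
dequantized = w_int8 * scb.view(-1, 1) * 7.874015718698502e-3
print(dequantized.dtype)  # torch.float32

# With the new dtype argument, the result is cast to the model's dtype (e.g. float16).
model_dtype = torch.float16
print(dequantized.to(model_dtype).dtype)  # torch.float16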