[LoRA] loading LoRA into a quantized base model #10550
Comments
I investigated and this is indeed a bug in PEFT. It is specific to 8-bit bnb. Here is a minimal reproducer:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
model_id = "facebook/opt-125m"
bnb_config = BitsAndBytesConfig(load_in_8bit=True) # load_in_4bit works
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
model = get_peft_model(model, LoraConfig(), low_cpu_mem_usage=True)

I will work on a fix. |
There was a bug in PEFT that occurred when trying to use the low_cpu_mem_usage=True option with 8-bit bitsandbytes quantized models. This bug is now fixed; see huggingface/diffusers#10550 for the bug report.
Bugfix: huggingface/peft#2325. I tested it locally with the Flux script posted above and the error disappeared. During the forward pass I ran into an OOM, but the issue itself should be addressed by the linked PR. |
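For reference, here is a minimal sketch of the kind of Flux setup being discussed (not the exact script referenced above; the repo id, LoRA repo/filename, and dtype are assumptions on my part):

```python
import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"

# Quantize only the transformer to 8-bit with bitsandbytes.
quant_config = BitsAndBytesConfig(load_in_8bit=True)
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)

# With the fix from huggingface/peft#2325 (PEFT installed from source), loading
# a LoRA into the 8-bit quantized transformer should no longer error out.
pipe.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
```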
This way of loading the fp8 model and the LoRA still has a problem:
and
|
File a separate issue without any |
Sorry, I just think my code is very simple; these are the reproduction snippets.
This will lead to:
And when I load another LoRA downloaded from https://civitai.com/, this will lead to:
|
And then I tried another way to load the model:
this leads to:
|
That error will go away if you install
Regarding the snippet in #10550 (comment), I think this is a separate problem stemming from
Notice the
@BenjaminBossan I know this doesn't directly correspond to
Minimally reproducible snippet:

import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from huggingface_hub import hf_hub_download
from optimum.quanto import qfloat8, quantize, freeze
from transformers import T5EncoderModel
bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.load_lora_weights(
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)

It leads to:

Traceback (most recent call last):
File "/home/sayak/diffusers/check_fp8.py", line 22, in <module>
pipe.load_lora_weights(
File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1559, in load_lora_weights
transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 2081, in _maybe_expand_lora_state_dict
base_weight_param = transformer_state_dict[base_param_name]
KeyError: 'context_embedder.weight'

In my mind, I can think of adding additional checks here:
diffusers/src/diffusers/loaders/lora_pipeline.py (line 2080 in 74b6752)
But perhaps some kind of a dispatcher would make more sense here.

EDIT: I investigated a bit and I think even if we determine the

Error

Traceback (most recent call last):
File "/home/sayak/diffusers/check_fp8.py", line 22, in <module>
pipe.load_lora_weights(
File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1564, in load_lora_weights
self.load_lora_into_transformer(
File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1628, in load_lora_into_transformer
transformer.load_lora_adapter(
File "/home/sayak/diffusers/src/diffusers/loaders/peft.py", line 302, in load_lora_adapter
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
File "/home/sayak/peft/src/peft/utils/save_and_load.py", line 445, in set_peft_model_state_dict
load_result = model.load_state_dict(peft_model_state_dict, strict=False, assign=True)
File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2564, in load_state_dict
load(self, state_dict)
File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2552, in load
load(child, child_state_dict, child_prefix) # noqa: F821
File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2552, in load
load(child, child_state_dict, child_prefix) # noqa: F821
File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2552, in load
load(child, child_state_dict, child_prefix) # noqa: F821
[Previous line repeated 1 more time]
File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2535, in load
module._load_from_state_dict(
File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 160, in _load_from_state_dict
deserialized_weight = WeightQBytesTensor.load_from_state_dict(
File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/tensor/weights/qbytes.py", line 77, in load_from_state_dict
inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.base_layer.weight._data'

Changes

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 7492ba028..812a5c883 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2053,6 +2053,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
@classmethod
def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
+ from optimum.quanto import QLinear
expanded_module_names = set()
transformer_state_dict = transformer.state_dict()
prefix = f"{cls.transformer_name}."
@@ -2068,6 +2069,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+ is_quanto_quantized = any(isinstance(m, QLinear) for _, m in transformer.named_modules())
for k in lora_module_names:
if k in unexpected_modules:
continue
@@ -2077,6 +2079,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
else f"{k.replace(prefix, '')}.weight"
)
+ if is_quanto_quantized:
+ base_param_name += "._data"
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] |
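For context on the `._data` suffix in the diff above, here is a minimal sketch (assuming optimum-quanto is installed; the toy module is purely illustrative) of how quanto serializes a quantized weight as inner tensors rather than a single `weight` entry:

```python
import torch
from optimum.quanto import freeze, qfloat8, quantize

# A toy module; quantize() swaps its nn.Linear child for a quanto QLinear.
model = torch.nn.Sequential(torch.nn.Linear(16, 16))
quantize(model, weights=qfloat8)
freeze(model)

# After freezing, the quantized weight is a tensor subclass made of inner
# tensors, so the state dict is expected to expose keys such as
# "0.weight._data" and "0.weight._scale" instead of a plain "0.weight",
# which is why the lookup above appends "._data".
print(sorted(model.state_dict().keys()))
```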
Even after I installed PEFT by
|
I said:
This means you need to do: |
OK, sorry for my carelessness, and thank you very much for your help!!! |
But I found that this way slows down the sampling speed by nearly 1.5x, while the optimum-quanto setup does not.
That is a separate issue and we can address that in #10550 (comment) |
Yes, the problem caused by Hyper-FLUX.1-dev-8steps-lora.safetensors can be solved this way.
This would lead to:
|
I think you misunderstood what I said in #10550 (comment) regarding Quanto support. It's not supported yet. In fact, the error you posted is already there in #10550 (comment). Could you please try to read things carefully before repeatedly commenting? This has happened multiple times now in different threads. |
Sorry for that; maybe my English reading ability needs to improve.
@sayakpaul I haven't checked the details of your quanto issue yet, but note that at the moment, quanto is not supported in PEFT. There is a PR, huggingface/peft#2000, which has become a bit stale at this point, but I plan to finish it soon 🤞.
@lhjlhj11 Just an update: the PR with the fix has been merged into PEFT, so you can now install it from the main branch instead of |
The issue has been fixed for
We still have to fix the case for Control LoRAs from FLUX. Will open a separate issue to track and work on it. |
Is the sampling speed influenced by bitsandbytes? |
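Regarding the speed question, a minimal sketch (illustrative only; the prompt, step count, and the assumption of a CUDA device are mine) of how one could time sampling to compare the bitsandbytes and quanto setups:

```python
import time

import torch

def average_sampling_time(pipe, prompt="a photo of a cat", steps=8, runs=3):
    # Warmup run so that one-time initialization cost is not measured.
    pipe(prompt, num_inference_steps=steps)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(runs):
        pipe(prompt, num_inference_steps=steps)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / runs

# Usage, assuming `pipe` is one of the pipelines built earlier in the thread:
# print(f"average seconds per image: {average_sampling_time(pipe):.2f}")
```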
Similar issues:
Reproduction
Happens on the main branch as well as on the v0.31.0-release branch.

Error
@BenjaminBossan any suggestions here?