[LoRA] loading LoRA into a quantized base model #10550


Closed
sayakpaul opened this issue Jan 13, 2025 · 18 comments

@sayakpaul
Member

Similar issues:

  1. [LoRA] Quanto Flux LoRA can't load #10512
  2. NF4 quantized flux models with loras #10496
Reproduction
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline
from huggingface_hub import hf_hub_download


transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=DiffusersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer_8bit,
    torch_dtype=torch.bfloat16,
).to("cuda")

pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), 
    adapter_name="hyper-sd"
)
pipe.set_adapters("hyper-sd", adapter_weights=0.125)

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."

image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    max_sequence_length=512,
    num_inference_steps=8,
    guidance_scale=50,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("out.jpg")

Happens on main as well as on the v0.31.0-release branch.

Error
Traceback (most recent call last):
  File "/home/sayak/diffusers/load_loras_flux.py", line 18, in <module>
    pipe.load_lora_weights(
  File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1846, in load_lora_weights
    self.load_lora_into_transformer(
  File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1948, in load_lora_into_transformer
    inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/mapping.py", line 260, in inject_adapter_in_model
    peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 141, in __init__
    super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 184, in __init__
    self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 501, in inject_adapter
    self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 239, in _create_and_replace
    self._replace_module(parent, target_name, new_module, target)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 263, in _replace_module
    new_module.to(child.weight.device)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1340, in to
    return self._apply(convert)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 927, in _apply
    param_applied = fn(param)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1333, in convert
    raise NotImplementedError(
NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.

@BenjaminBossan any suggestions here?

@BenjaminBossan
Member

I investigated and this is indeed a bug in PEFT. It is specific to 8bit bnb. Here is a minimal reproducer:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

model_id = "facebook/opt-125m"
bnb_config = BitsAndBytesConfig(load_in_8bit=True)  # load_in_4bit works
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
model = get_peft_model(model, LoraConfig(), low_cpu_mem_usage=True)

I will work on a fix.

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Jan 13, 2025
There was a bug in PEFT that occurred when trying to use the
low_cpu_mem_usage=True option with 8bit bitsandbytes quantized models.
This bug is fixed now.

See huggingface/diffusers#10550 for the bug report.
@BenjaminBossan
Member

Bugfix: huggingface/peft#2325

I tested it locally with the Flux script posted above and the error disappeared. During the forward pass, I ran into OOM, but the issue itself should be addressed by the linked PR.

@lhjlhj11

self.transformer = FluxTransformer2DModel.from_single_file(flux_dev_fp8_path)  # path to the flux-dev-fp8 file
quantize(self.transformer, weights=qfloat8)
freeze(self.transformer)
self.text_encoder_2 = T5EncoderModel.from_pretrained(flux_repo, subfolder="text_encoder_2")  # the flux-dev repo
quantize(self.text_encoder_2, weights=qfloat8)
freeze(self.text_encoder_2)

self.pipe = FluxPipeline.from_pretrained(flux_repo, transformer=None, text_encoder_2=None, torch_dtype=torch.bfloat16).to(self.device)

self.pipe.transformer = self.transformer
self.pipe.text_encoder_2 = self.text_encoder_2

self.pipe.load_lora_weights("ByteDance/Hyper-SD", device=self.device, adapter_name="8steps")
self.pipe.set_adapters(["8steps"], adapter_weights=[0.125])

Loading an fp8 model and a LoRA this way still has a problem:

Traceback (most recent call last):
  File "/home/iv/anaconda3/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/iv/anaconda3/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/app.py", line 112, in process_data
    IVmodelHandle.loadModel()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/BaseModel.py", line 28, in loadModel
    self.model = GlobalConfig.get_model()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/GlobalConfig.py", line 76, in get_model
    modelHandler = model()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/FluxTxt2Img/Flux_txt2img.py", line 131, in __init__
    self.pipe.load_lora_weights(load_file(os.path.join(self.model_root, self.config["8steps_lora"]), device=self.device), adapter_name="8steps")
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1559, in load_lora_weights
    transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 2080, in _maybe_expand_lora_state_dict
    base_weight_param = transformer_state_dict[base_param_name]
KeyError: 'context_embedder.weight'

and

Traceback (most recent call last):
  File "/home/iv/anaconda3/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/iv/anaconda3/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/app.py", line 112, in process_data
    IVmodelHandle.loadModel()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/BaseModel.py", line 28, in loadModel
    self.model = GlobalConfig.get_model()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/GlobalConfig.py", line 76, in get_model
    modelHandler = model()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/FluxTxt2Img/Flux_txt2img.py", line 143, in __init__
    self.pipe.load_lora_weights(load_file(os.path.join(self.model_root, self.config["Flux"][style]["lora_repo"][0]), device=self.device), adapter_name=style)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1559, in load_lora_weights
    transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 2080, in _maybe_expand_lora_state_dict
    base_weight_param = transformer_state_dict[base_param_name]
KeyError: 'single_transformer_blocks.0.attn.to_k.weight'

@sayakpaul
Member Author

File a separate issue without any self references and make it FULLY REPRODUCIBLE. We cannot be expected to download additional files for which we have no instructions whatsoever. I have said this multiple times to you, and you seem to be completely disregarding it. Please either provide a minimal yet fully reproducible snippet, or we will have to stop engaging.

@lhjlhj11

lhjlhj11 commented Jan 14, 2025

File a separate issue without any self references and make it FULLY REPRODUCIBLE. We cannot be expected to download additional files for which we have no instructions whatsoever. I have said this multiple times to you, and you seem to be completely disregarding it. Please either provide a minimal yet fully reproducible snippet, or we will have to stop engaging.

Sorry, I just thought my code was simple enough. Here is a reproducible snippet:

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)

text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2

pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd")
pipe.set_adapters("hyper-sd", adapter_weights=0.125)

This leads to:

Traceback (most recent call last):
  File "/home/iv/anaconda3/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/iv/anaconda3/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/app.py", line 112, in process_data
    IVmodelHandle.loadModel()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/BaseModel.py", line 28, in loadModel
    self.model = GlobalConfig.get_model()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/GlobalConfig.py", line 76, in get_model
    modelHandler = model()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/FluxTxt2Img/Flux_txt2img.py", line 131, in __init__
    self.pipe.load_lora_weights(load_file(os.path.join(self.model_root, self.config["8steps_lora"]), device=self.device), adapter_name="8steps")
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1559, in load_lora_weights
    transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 2080, in _maybe_expand_lora_state_dict
    base_weight_param = transformer_state_dict[base_param_name]
KeyError: 'context_embedder.weight'

And when I load another LoRA downloaded from https://civitai.com/, it leads to:

Traceback (most recent call last):
  File "/home/iv/anaconda3/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/iv/anaconda3/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/app.py", line 112, in process_data
    IVmodelHandle.loadModel()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/BaseModel.py", line 28, in loadModel
    self.model = GlobalConfig.get_model()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/GlobalConfig.py", line 76, in get_model
    modelHandler = model()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/FluxTxt2Img/Flux_txt2img.py", line 143, in __init__
    self.pipe.load_lora_weights(load_file(os.path.join(self.model_root, self.config["Flux"][style]["lora_repo"][0]), device=self.device), adapter_name=style)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1559, in load_lora_weights
    transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 2080, in _maybe_expand_lora_state_dict
    base_weight_param = transformer_state_dict[base_param_name]
KeyError: 'single_transformer_blocks.0.attn.to_k.weight'

@lhjlhj11

lhjlhj11 commented Jan 14, 2025

And then, I tried another way to load the model:

transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=DiffusersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer_8bit,
    torch_dtype=torch.bfloat16,
).to("cuda")

pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), 
    adapter_name="hyper-sd"
)
pipe.set_adapters("hyper-sd", adapter_weights=0.125)

This leads to:

Traceback (most recent call last):
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/peft.py", line 301, in load_lora_adapter
    inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/peft/mapping.py", line 260, in inject_adapter_in_model
    peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 141, in __init__
    super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 184, in __init__
    self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 501, in inject_adapter
    self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 239, in _create_and_replace
    self._replace_module(parent, target_name, new_module, target)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 263, in _replace_module
    new_module.to(child.weight.device)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1174, in to
    return self._apply(convert)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 780, in _apply
    module._apply(fn)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 780, in _apply
    module._apply(fn)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 805, in _apply
    param_applied = fn(param)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1167, in convert
    raise NotImplementedError(
NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/iv/anaconda3/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/iv/anaconda3/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/app.py", line 112, in process_data
    IVmodelHandle.loadModel()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/BaseModel.py", line 28, in loadModel
    self.model = GlobalConfig.get_model()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/GlobalConfig.py", line 76, in get_model
    modelHandler = model()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/FluxTxt2Img/Flux_txt2img.py", line 141, in __init__
    self.pipe.load_lora_weights(load_file(os.path.join(self.model_root, self.config["8steps_lora"]), device=self.device), adapter_name="8steps")
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1564, in load_lora_weights
    self.load_lora_into_transformer(
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1628, in load_lora_into_transformer
    transformer.load_lora_adapter(
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/peft.py", line 311, in load_lora_adapter
    self.peft_config.pop(adapter_name)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 173, in __getattr__
    return super().__getattr__(name)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1729, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'FluxTransformer2DModel' object has no attribute 'peft_config'. Did you mean: 'from_config'?

@sayakpaul
Member Author

sayakpaul commented Jan 14, 2025

And then, I tried another way to load the model:

That error will go away if you install peft from this branch: huggingface/peft#2325.

Regarding the snippet in #10550 (comment), I think this is a separate problem stemming from quanto, as it rewrites the state-dict keys, for example:

['time_text_embed.timestep_embedder.linear_1.weight._data', 'time_text_embed.timestep_embedder.linear_1.weight._scale', 'time_text_embed.timestep_embedder.linear_1.bias', 'time_text_embed.timestep_embedder.linear_1.input_scale', 'time_text_embed.timestep_embedder.linear_1.output_scale']

Notice the ._data, ._scale, .input_scale, and .output_scale suffixes.
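
As a quick illustration of where those keys come from, here is a minimal sketch (my own toy example, assuming only torch and optimum-quanto; the model is illustrative):

import torch
from optimum.quanto import qfloat8, quantize, freeze

# Tiny illustrative model; quantize() swaps the nn.Linear for a quanto QLinear.
model = torch.nn.Sequential(torch.nn.Linear(8, 8))
quantize(model, weights=qfloat8)
freeze(model)

# After freeze(), the weight is serialized as quantized data plus scale
# tensors, so there is no plain "0.weight" entry anymore.
print(list(model.state_dict().keys()))
# Expect keys like "0.weight._data" and "0.weight._scale".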

@BenjaminBossan I know this doesn't directly correspond to peft but I wanted your opinions on how to best tackle this.

Minimally reproducible snippet:
import torch 
from diffusers import FluxTransformer2DModel, FluxPipeline
from huggingface_hub import hf_hub_download
from optimum.quanto import qfloat8, quantize, freeze
from transformers import T5EncoderModel

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)

text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2

pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)

It leads to:

Traceback (most recent call last):
  File "/home/sayak/diffusers/check_fp8.py", line 22, in <module>
    pipe.load_lora_weights(
  File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1559, in load_lora_weights
    transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
  File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 2081, in _maybe_expand_lora_state_dict
    base_weight_param = transformer_state_dict[base_param_name]
KeyError: 'context_embedder.weight'

Off the top of my head, we could add additional checks here:

base_weight_param = transformer_state_dict[base_param_name]

But perhaps some kind of a dispatcher would make more sense here.
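
For illustration, a hypothetical sketch of what such a dispatcher could look like (the helper name and the candidate list are made up, not existing diffusers API):

def resolve_base_param_name(module_name: str, transformer_state_dict: dict) -> str:
    # Probe the PEFT-wrapped layout, the plain layout, and the quanto
    # serialization layout, where the weight lives under "<name>.weight._data".
    candidates = [
        f"{module_name}.base_layer.weight",
        f"{module_name}.weight",
        f"{module_name}.weight._data",
    ]
    for candidate in candidates:
        if candidate in transformer_state_dict:
            return candidate
    raise KeyError(f"No base weight found for {module_name!r}")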

EDIT: I investigated a bit, and I think that even if we determine the base_param_name correctly for Quanto, there also needs to be support for it in peft.

Error
Traceback (most recent call last):
  File "/home/sayak/diffusers/check_fp8.py", line 22, in <module>
    pipe.load_lora_weights(
  File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1564, in load_lora_weights
    self.load_lora_into_transformer(
  File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1628, in load_lora_into_transformer
    transformer.load_lora_adapter(
  File "/home/sayak/diffusers/src/diffusers/loaders/peft.py", line 302, in load_lora_adapter
    incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
  File "/home/sayak/peft/src/peft/utils/save_and_load.py", line 445, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False, assign=True)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2564, in load_state_dict
    load(self, state_dict)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2552, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2552, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2552, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  [Previous line repeated 1 more time]
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2535, in load
    module._load_from_state_dict(
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 160, in _load_from_state_dict
    deserialized_weight = WeightQBytesTensor.load_from_state_dict(
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/tensor/weights/qbytes.py", line 77, in load_from_state_dict
    inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.base_layer.weight._data'
Changes
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 7492ba028..812a5c883 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -2053,6 +2053,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
 
     @classmethod
     def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
+        from optimum.quanto import QLinear
         expanded_module_names = set()
         transformer_state_dict = transformer.state_dict()
         prefix = f"{cls.transformer_name}."
@@ -2068,6 +2069,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
 
         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        is_quanto_quantized = any(isinstance(m, QLinear) for _, m in transformer.named_modules())
         for k in lora_module_names:
             if k in unexpected_modules:
                 continue
@@ -2077,6 +2079,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
                 else f"{k.replace(prefix, '')}.weight"
             )
+            if is_quanto_quantized:
+                base_param_name += "._data"
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

@lhjlhj11

That error will go away if you install peft from this branch: huggingface/peft#2325.

Even after installing peft with pip install git+https://github.com/huggingface/peft, the problem is still here.

accelerate               1.2.1
annotated-types          0.7.0
anyio                    4.8.0
asttokens                3.0.0
bitsandbytes             0.45.0
boto3                    1.35.97
botocore                 1.35.97
certifi                  2024.12.14
charset-normalizer       3.4.1
click                    8.1.8
coloredlogs              15.0.1
comm                     0.2.2
compel                   2.0.3
debugpy                  1.8.11
decorator                5.1.1
diffusers                0.33.0.dev0
einops                   0.8.0
exceptiongroup           1.2.2
executing                2.1.0
fastapi                  0.115.6
ffmpeg                   1.4
filelock                 3.13.1
flatbuffers              24.12.23
fsspec                   2024.2.0
h11                      0.14.0
huggingface-hub          0.27.0
humanfriendly            10.0
idna                     3.10
importlib_metadata       8.5.0
ipykernel                6.29.5
ipython                  8.31.0
ipywidgets               8.1.5
jedi                     0.19.2
Jinja2                   3.1.3
jmespath                 1.0.1
jupyter_client           8.6.3
jupyter_core             5.7.2
jupyterlab_widgets       3.0.13
lark                     1.2.2
MarkupSafe               2.1.5
matplotlib-inline        0.1.7
mpmath                   1.3.0
nest-asyncio             1.6.0
networkx                 3.2.1
ninja                    1.11.1.3
numpy                    1.26.3
nvidia-cublas-cu12       12.4.2.65
nvidia-cuda-cupti-cu12   12.4.99
nvidia-cuda-nvrtc-cu12   12.4.99
nvidia-cuda-runtime-cu12 12.4.99
nvidia-cudnn-cu12        9.1.0.70
nvidia-cufft-cu12        11.2.0.44
nvidia-curand-cu12       10.3.5.119
nvidia-cusolver-cu12     11.6.0.99
nvidia-cusparse-cu12     12.3.0.142
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.4.99
nvidia-nvtx-cu12         12.4.99
onnxruntime-gpu          1.20.1
opencv-python            4.10.0.84
optimum-quanto           0.2.5
packaging                24.2
parso                    0.8.4
peft                     0.14.1.dev0
pexpect                  4.9.0
pillow                   10.2.0
pip                      23.0.1
platformdirs             4.3.6
prompt_toolkit           3.0.48
protobuf                 5.29.2
psutil                   6.1.1
ptyprocess               0.7.0
pure_eval                0.2.3
pycryptodome             3.21.0
pydantic                 2.10.5
pydantic_core            2.27.2
Pygments                 2.18.0
PyMySQL                  1.1.1
pyparsing                3.2.0
python-dateutil          2.9.0.post0
PyYAML                   6.0.2
pyzmq                    26.2.0
regex                    2024.11.6
requests                 2.32.3
s3transfer               0.10.4
safetensors              0.4.5
schedule                 1.2.2
sd-embed                 1.240829.1
sentencepiece            0.2.0
setuptools               65.5.0
six                      1.17.0
sniffio                  1.3.1
stack-data               0.6.3
starlette                0.41.3
sympy                    1.13.1
tenacity                 9.0.0
timm                     0.6.7
tokenizers               0.21.0
torch                    2.4.1+cu124
torchaudio               2.4.1+cu124
torchvision              0.19.1+cu124
tornado                  6.4.2
tqdm                     4.67.1
traitlets                5.14.3
transformers             4.47.1
triton                   3.0.0
typing_extensions        4.12.2
urllib3                  2.3.0
uvicorn                  0.34.0
wcwidth                  0.2.13
widgetsnbextension       4.0.13
zipp                     3.21.0

@sayakpaul
Member Author

Even after installing peft with pip install git+https://github.com/huggingface/peft, the problem is still here.

I said:

That error will go away if you install peft from this branch: huggingface/peft#2325.

This means you need to do: pip install git+https://github.com/BenjaminBossan/peft@fix-low-cpu-mem-usage-bnb-8bit

@lhjlhj11

This means you need to do: pip install git+https://github.com/BenjaminBossan/peft@fix-low-cpu-mem-usage-bnb-8bit

OK, sorry for my carelessness, and thank you very much for your help!!!

@lhjlhj11


But I found that the bitsandbytes way slows down the sampling speed by nearly 1.5x, while optimum-quanto does not.

@sayakpaul
Member Author

That is a separate issue and we can address that in #10550 (comment)

@lhjlhj11

That is a separate issue and we can address that in #10550 (comment)

Yes, the problem caused by Hyper-FLUX.1-dev-8steps-lora.safetensors can be solved this way.
However, another issue I mentioned is still there: when I load other LoRAs downloaded from https://civitai.com/, it leads to an error.
For example, here I load a LoRA picked at random from the Hugging Face Hub:

import torch 
from diffusers import FluxTransformer2DModel, FluxPipeline
from huggingface_hub import hf_hub_download
from optimum.quanto import qfloat8, quantize, freeze
from transformers import T5EncoderModel

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)

text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2

pipe.load_lora_weights(
    hf_hub_download("Shakker-Labs/FLUX.1-dev-LoRA-Logo-Design", "FLUX-dev-lora-Logo-Design.safetensors"), adapter_name="logo"
)

This would lead to:

Traceback (most recent call last):
  File "/home/iv/anaconda3/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/iv/anaconda3/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/app.py", line 112, in process_data
    IVmodelHandle.loadModel()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/BaseModel.py", line 28, in loadModel
    self.model = GlobalConfig.get_model()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/GlobalConfig.py", line 76, in get_model
    modelHandler = model()
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/IVAlgoHTTP/algos/FluxTxt2Img/Flux_txt2img.py", line 145, in __init__
    self.pipe.load_lora_weights(hf_hub_download("Shakker-Labs/FLUX.1-dev-LoRA-Logo-Design", "FLUX-dev-lora-Logo-Design.safetensors"), adapter_name="logo")
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1559, in load_lora_weights
    transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
  File "/home/iv/Algo_new/LouHaijie/IVAlgoHTTP/66v/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 2085, in _maybe_expand_lora_state_dict
    base_weight_param = transformer_state_dict[base_param_name]
KeyError: 'single_transformer_blocks.0.attn.to_k.weight._data'

@sayakpaul
Member Author

I think you misunderstood what I said in #10550 (comment) regarding Quanto support. It's not supported yet. In fact, the error you posted is already present in #10550 (comment).

Could you please try to read things carefully before repeatedly commenting? This has happened multiple times now in different threads.

@lhjlhj11

I think you misunderstood what I said in #10550 (comment) regarding Quanto support. It's not supported yet. In fact, the error you posted is already present in #10550 (comment).

Could you please try to read things carefully before repeatedly commenting? This has happened multiple times now in different threads.

Sorry for that; maybe my English reading ability needs to be improved.

@BenjaminBossan
Member

@sayakpaul I haven't checked the details of your quanto issue yet, but note that at the moment, quanto is not supported in PEFT. There is a PR, huggingface/peft#2000, which has become a bit stale at this point, but I plan to finish it soon🤞.

@lhjlhj11 Just an update: the PR with the fix has been merged into PEFT, so you can now install from the main branch instead of https://github.com/BenjaminBossan/peft@fix-low-cpu-mem-usage-bnb-8bit.

@sayakpaul
Member Author

Issue has been fixed for bitsandbytes quantized models. Quanto support is contingent on peft (and likely not a priority).

We still have to fix the case for Control LoRAs from FLUX. Will open a separate issue to track and work on it.

@Jay-9-c

Jay-9-c commented Jan 16, 2025

Issue has been fixed for bitsandbytes quantized models. Quanto support is contingent on peft (and likely not a priority).

We still have to fix the case for Control LoRAs from FLUX. Will open a separate issue to track and work on it.

Is the sampling speed influenced by bitsandbytes?
