Commit 71de5b7
[LoRA] quality of life improvements in the loading semantics and docs (#3180)
* 👽 qol improvements for LoRA.
* better function name?
* fix: LoRA weight loading with the new format.
* address Patrick's comments.
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <[email protected]>
* change wording around encouraging the use of load_lora_weights().
* fix: function name.

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 256e696 commit 71de5b7
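
In short, this commit steers users toward `pipe.load_lora_weights()` (which understands the new `unet.` / `text_encoder.` prefixed checkpoint format as well as the old UNet-only format) instead of `pipe.unet.load_attn_procs()`. A minimal sketch of the encouraged flow, using the example repository referenced in the diffs below and the `base_model` RepoCard lookup the docs now recommend:

```py
import torch
from diffusers import StableDiffusionPipeline
from huggingface_hub.repocard import RepoCard

lora_model_id = "patrickvonplaten/lora_dreambooth_dog_example"
# The repo's `base_model` tag records which checkpoint the LoRA was trained against.
base_model_id = RepoCard.load(lora_model_id).data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights(lora_model_id)  # loads UNet (and, if present, text encoder) LoRA params

image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```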

7 files changed (+123, −27 lines)

docs/source/en/_toctree.yml

+3-3
@@ -171,7 +171,7 @@
 - local: api/pipelines/semantic_stable_diffusion
   title: Semantic Guidance
 - local: api/pipelines/spectrogram_diffusion
-  title: "Spectrogram Diffusion"
+  title: Spectrogram Diffusion
 - sections:
 - local: api/pipelines/stable_diffusion/overview
   title: Overview
@@ -238,6 +238,8 @@
   title: DPM Discrete Scheduler
 - local: api/schedulers/dpm_discrete_ancestral
   title: DPM Discrete Scheduler with ancestral sampling
+- local: api/schedulers/dpm_sde
+  title: DPMSolverSDEScheduler
 - local: api/schedulers/euler_ancestral
   title: Euler Ancestral Scheduler
 - local: api/schedulers/euler
@@ -266,8 +268,6 @@
   title: VP-SDE
 - local: api/schedulers/vq_diffusion
   title: VQDiffusionScheduler
-- local: api/schedulers/dpm_sde
-  title: DPMSolverSDEScheduler
 title: Schedulers
 - sections:
 - local: api/experimental/rl

docs/source/en/training/lora.mdx

+33-3
@@ -115,7 +115,7 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
 </Tip>

 ```py
->>> pipe.unet.load_attn_procs(model_path)
+>>> pipe.unet.load_attn_procs(lora_model_path)
 >>> pipe.to("cuda")
 # use half the weights from the LoRA finetuned model and half the weights from the base model

@@ -128,6 +128,25 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
 >>> image.save("blue_pokemon.png")
 ```

+<Tip>
+
+If you are loading the LoRA parameters from the Hub and if the Hub repository has
+a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then
+you can do:
+
+```py
+from huggingface_hub.repocard import RepoCard
+
+lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+...
+```
+
+</Tip>
+
 ## DreamBooth

 [DreamBooth](https://arxiv.org/abs/2208.12242) is a finetuning technique for personalizing a text-to-image model like Stable Diffusion to generate photorealistic images of a subject in different contexts, given a few images of the subject. However, DreamBooth is very sensitive to hyperparameters and it is easy to overfit. Some important hyperparameters to consider include those that affect the training time (learning rate, number of training steps), and inference time (number of steps, scheduler type).
@@ -208,7 +227,7 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m
 </Tip>

 ```py
->>> pipe.unet.load_attn_procs(model_path)
+>>> pipe.unet.load_attn_procs(lora_model_path)
 >>> pipe.to("cuda")
 # use half the weights from the LoRA finetuned model and half the weights from the base model

@@ -222,4 +241,15 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m

 >>> image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
 >>> image.save("bucket-dog.png")
-```
+```
+
+Note that the use of [`LoraLoaderMixin.load_lora_weights`] is preferred to [`UNet2DConditionLoadersMixin.load_attn_procs`] for loading LoRA parameters. This is because
+[`LoraLoaderMixin.load_lora_weights`] can handle the following situations:
+
+* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:
+
+```py
+pipe.load_lora_weights(lora_model_path)
+```
+
+* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
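
For the prefixed case in the last bullet (a checkpoint like [`sayakpaul/dreambooth`](https://huggingface.co/sayakpaul/dreambooth) whose state-dict keys start with `unet.` and `text_encoder.`), the loading call is the same; a minimal sketch follows, where the base checkpoint identifier is a placeholder you would resolve from the repo's `base_model` tag or your own training setup:

```py
import torch
from diffusers import StableDiffusionPipeline

base_model_id = "<base checkpoint the LoRA was trained on>"  # placeholder, not a real model ID
pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")

# One call loads both the `unet.*` and the `text_encoder.*` LoRA parameters.
pipe.load_lora_weights("sayakpaul/dreambooth")

image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
```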

examples/dreambooth/README.md

+28-1
@@ -355,7 +355,7 @@ The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dr
 The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
 You can use the `Step` slider to see how the model learned the features of our subject while the model trained.

-Optionally, we can also train additional LoRA layers for the text encoder. Specify the `train_text_encoder` argument above for that. If you're interested to know more about how we
+Optionally, we can also train additional LoRA layers for the text encoder. Specify the `--train_text_encoder` argument above for that. If you're interested to know more about how we
 enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).

 With the default hyperparameters from the above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).
@@ -387,6 +387,33 @@ Finally, we can run the model in inference.
 image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
 ```

+If you are loading the LoRA parameters from the Hub and if the Hub repository has
+a `base_model` tag (such as [this](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/blob/main/README.md?code=true#L4)), then
+you can do:
+
+```py
+from huggingface_hub.repocard import RepoCard
+
+lora_model_id = "patrickvonplaten/lora_dreambooth_dog_example"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+...
+```
+
+**Note** that we will gradually be deprecating the use of [`UNet2DConditionLoadersMixin.load_attn_procs`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs) since we now have a more general
+method to load the LoRA parameters -- [`LoraLoaderMixin.load_lora_weights`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights). This is because
+[`LoraLoaderMixin.load_lora_weights`] can handle the following situations:
+
+* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:
+
+```py
+pipe.load_lora_weights(lora_model_path)
+```
+
+* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
+
 ## Training with Flax/JAX

 For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.

examples/dreambooth/train_dreambooth_lora.py

+1-1
@@ -1045,7 +1045,7 @@ def main(args):
         pipeline = pipeline.to(accelerator.device)

         # load attention processors
-        pipeline.load_attn_procs(args.output_dir)
+        pipeline.load_lora_weights(args.output_dir)

         # run inference
         if args.validation_prompt and args.num_validation_images > 0:

examples/test_examples.py

+6-2
@@ -281,10 +281,14 @@ def test_dreambooth_lora_with_text_encoder(self):
         # save_pretrained smoke test
         self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))

-        # the names of the keys of the state dict should either start with `unet`
-        # or `text_encoder`.
+        # check `text_encoder` is present at all.
         lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
         keys = lora_state_dict.keys()
+        is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
+        self.assertTrue(is_text_encoder_present)
+
+        # the names of the keys of the state dict should either start with `unet`
+        # or `text_encoder`.
         is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
         self.assertTrue(is_correct_naming)
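
The assertions above encode the key-naming convention of the new serialization format. The same check can be run by hand on any saved `pytorch_lora_weights.bin` — a minimal sketch, where the output directory is a placeholder for wherever the training script wrote its checkpoint:

```py
import os

import torch

output_dir = "<path/to/lora/output>"  # placeholder
state_dict = torch.load(os.path.join(output_dir, "pytorch_lora_weights.bin"))

# New-format checkpoints prefix every key with `unet` or `text_encoder`.
unet_keys = [k for k in state_dict if k.startswith("unet")]
text_encoder_keys = [k for k in state_dict if k.startswith("text_encoder")]
print(f"unet params: {len(unet_keys)}, text encoder params: {len(text_encoder_keys)}")
```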

examples/text_to_image/README.md

+15
@@ -229,6 +229,21 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
 image.save("pokemon.png")
 ```

+If you are loading the LoRA parameters from the Hub and if the Hub repository has
+a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then
+you can do:
+
+```py
+from huggingface_hub.repocard import RepoCard
+
+lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+...
+```
+
 ## Training with Flax/JAX

 For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.

src/diffusers/loaders.py

+37-17
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,8 @@

 logger = logging.get_logger(__name__)

+TEXT_ENCODER_NAME = "text_encoder"
+UNET_NAME = "unet"

 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
@@ -87,6 +90,9 @@ def map_from(module, state_dict, *args, **kwargs):


 class UNet2DConditionLoadersMixin:
+    text_encoder_name = TEXT_ENCODER_NAME
+    unet_name = UNET_NAME
+
     def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         r"""
         Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
@@ -225,6 +231,18 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())

         if is_lora:
+            is_new_lora_format = all(
+                key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
+            )
+            if is_new_lora_format:
+                # Strip the `"unet"` prefix.
+                is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
+                if is_text_encoder_present:
+                    warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
+                    warnings.warn(warn_message)
+                unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
+                state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
+
             lora_grouped_dict = defaultdict(dict)
             for key, value in state_dict.items():
                 attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
@@ -672,8 +690,8 @@ class LoraLoaderMixin:

     </Tip>
     """
-    text_encoder_name = "text_encoder"
-    unet_name = "unet"
+    text_encoder_name = TEXT_ENCODER_NAME
+    unet_name = UNET_NAME

     def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         r"""
@@ -810,33 +828,33 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())
-
-        # Load the layers corresponding to UNet.
-        if all(key.startswith(self.unet_name) for key in keys):
+        if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
+            # Load the layers corresponding to UNet.
+            unet_keys = [k for k in keys if k.startswith(self.unet_name)]
             logger.info(f"Loading {self.unet_name}.")
-            unet_lora_state_dict = {k: v for k, v in state_dict.items() if k.startswith(self.unet_name)}
+            unet_lora_state_dict = {
+                k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
+            }
             self.unet.load_attn_procs(unet_lora_state_dict)

-        # Load the layers corresponding to text encoder and make necessary adjustments.
-        elif all(key.startswith(self.text_encoder_name) for key in keys):
+            # Load the layers corresponding to text encoder and make necessary adjustments.
+            text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
             logger.info(f"Loading {self.text_encoder_name}.")
             text_encoder_lora_state_dict = {
-                k: v for k, v in state_dict.items() if k.startswith(self.text_encoder_name)
+                k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
-            attn_procs_text_encoder = self.load_attn_procs(text_encoder_lora_state_dict)
-            self._modify_text_encoder(attn_procs_text_encoder)
+            if len(text_encoder_lora_state_dict) > 0:
+                attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
+                self._modify_text_encoder(attn_procs_text_encoder)

         # Otherwise, we're dealing with the old format. This means the `state_dict` should only
         # contain the module names of the `unet` as its keys WITHOUT any prefix.
         elif not all(
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
         ):
             self.unet.load_attn_procs(state_dict)
-            deprecation_message = "You have saved the LoRA weights using the old format. This will be"
-            " deprecated soon. To convert the old LoRA weights to the new format, you can first load them"
-            " in a dictionary and then create a new dictionary like the following:"
-            " `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
-            deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False)
+            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
+            warnings.warn(warn_message)

     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
@@ -872,7 +890,9 @@ def _get_lora_layer_attribute(self, name: str) -> str:
         else:
             return "to_out_lora"

-    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def _load_text_encoder_attn_procs(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
+    ):
         r"""
         Load pretrained attention processor layers for
         [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
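
The loader changes all hinge on one convention: a new-format LoRA state dict prefixes every key with `unet.` or `text_encoder.`, and each consumer strips its own prefix before handing the sub-dict to the attention-processor machinery (old-format dicts, with bare UNet module names, pass through unchanged). A standalone sketch of that splitting step, written against plain dicts rather than the `diffusers` classes:

```py
from typing import Dict, Tuple

import torch

UNET_NAME = "unet"
TEXT_ENCODER_NAME = "text_encoder"


def split_lora_state_dict(
    state_dict: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """Split a new-format LoRA state dict into UNet and text encoder parts, stripping the prefixes."""
    is_new_format = all(k.startswith((UNET_NAME, TEXT_ENCODER_NAME)) for k in state_dict)
    if not is_new_format:
        # Old format: keys are already bare UNet module names, nothing to split.
        return dict(state_dict), {}

    unet_sd = {
        k.replace(f"{UNET_NAME}.", "", 1): v for k, v in state_dict.items() if k.startswith(UNET_NAME)
    }
    text_encoder_sd = {
        k.replace(f"{TEXT_ENCODER_NAME}.", "", 1): v
        for k, v in state_dict.items()
        if k.startswith(TEXT_ENCODER_NAME)
    }
    return unet_sd, text_encoder_sd
```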
