
Commit 9bd09e1

williamberman and patrickvonplaten authored and Jimmy committed
Refactor LoRA (huggingface#3778)
* refactor to support patching LoRA into T5
  - instantiate the lora linear layer on the same device as the regular linear layer
  - get lora rank from state dict
  - tests
  - fmt
  - can create lora layer in float32 even when rest of model is float16
  - fix loading model hook
  - remove load_lora_weights_ and T5 dispatching
  - remove Unet#attn_processors_state_dict
  - docstrings
* text encoder monkeypatch class method
* fix test

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 201c743 commit 9bd09e1
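For orientation before reading the diff: after this refactor the training script keeps plain lists of LoRA parameters and plain state dicts instead of AttnProcsLayers wrappers, and the text encoder is patched through a LoraLoaderMixin class method rather than a throwaway pipeline. The sketch below condenses that pattern outside the script. It is a minimal sketch, not the script itself: it assumes a diffusers version that contains this commit, uses "runwayml/stable-diffusion-v1-5" purely as an example checkpoint, and the per-block hidden_size lookup mirrors the script's existing logic rather than anything introduced here.

import itertools

import torch
import torch.nn.functional as F
from transformers import CLIPTextModel

from diffusers import UNet2DConditionModel
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

model_id = "runwayml/stable-diffusion-v1-5"  # example checkpoint, not prescribed by the commit
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

# One LoRA attention processor per attention block; keep its parameters around
# so they can be handed to the optimizer directly (no AttnProcsLayers wrapper).
lora_attn_processor_class = (
    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
unet_lora_attn_procs = {}
unet_lora_parameters = []
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    # hidden_size lookup as in the existing dreambooth script (unchanged by this commit)
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
    unet_lora_attn_procs[name] = module
    unet_lora_parameters.extend(module.parameters())

unet.set_attn_processor(unet_lora_attn_procs)

# Monkey-patch LoRA linear layers into the text encoder; the commit keeps them
# in float32 even when the rest of the model is loaded in fp16.
text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)

# Both parameter lists can feed the optimizer, mirroring params_to_optimize in the script.
params_to_optimize = itertools.chain(unet_lora_parameters, text_lora_parameters)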

File tree

6 files changed: +437 −377 lines changed


examples/dreambooth/train_dreambooth_lora.py (+74 −77)
@@ -23,6 +23,7 @@
 import shutil
 import warnings
 from pathlib import Path
+from typing import Dict
 
 import numpy as np
 import torch
@@ -50,7 +51,10 @@
     StableDiffusionPipeline,
     UNet2DConditionModel,
 )
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
+from diffusers.loaders import (
+    LoraLoaderMixin,
+    text_encoder_lora_state_dict,
+)
 from diffusers.models.attention_processor import (
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
@@ -60,7 +64,7 @@
     SlicedAttnAddedKVProcessor,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
@@ -653,6 +657,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
     return prompt_embeds
 
 
+def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
+    r"""
+    Returns:
+        a state dict containing just the attention processor parameters.
+    """
+    attn_processors = unet.attn_processors
+
+    attn_processors_state_dict = {}
+
+    for attn_processor_key, attn_processor in attn_processors.items():
+        for parameter_key, parameter in attn_processor.state_dict().items():
+            attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
+
+    return attn_processors_state_dict
+
+
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
 
@@ -833,6 +853,7 @@ def main(args):
 
     # Set correct lora layers
     unet_lora_attn_procs = {}
+    unet_lora_parameters = []
     for name, attn_processor in unet.attn_processors.items():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
         if name.startswith("mid_block"):
@@ -850,35 +871,18 @@
         lora_attn_processor_class = (
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
-        unet_lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.rank,
-        )
+
+        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        unet_lora_attn_procs[name] = module
+        unet_lora_parameters.extend(module.parameters())
 
     unet.set_attn_processor(unet_lora_attn_procs)
-    unet_lora_layers = AttnProcsLayers(unet.attn_processors)
 
     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
-    # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
-    # we first load a dummy pipeline with the text encoder and then do the monkey-patching.
-    text_encoder_lora_layers = None
+    # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        text_lora_attn_procs = {}
-        for name, module in text_encoder.named_modules():
-            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                text_lora_attn_procs[name] = LoRAAttnProcessor(
-                    hidden_size=module.out_proj.out_features,
-                    cross_attention_dim=None,
-                    rank=args.rank,
-                )
-        text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, text_encoder=text_encoder
-        )
-        temp_pipeline._modify_text_encoder(text_lora_attn_procs)
-        text_encoder = temp_pipeline.text_encoder
-        del temp_pipeline
+        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
@@ -887,23 +891,13 @@ def save_model_hook(models, weights, output_dir):
         unet_lora_layers_to_save = None
         text_encoder_lora_layers_to_save = None
 
-        if args.train_text_encoder:
-            text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
-            unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
-
         for model in models:
-            state_dict = model.state_dict()
-
-            if (
-                text_encoder_lora_layers is not None
-                and text_encoder_keys is not None
-                and state_dict.keys() == text_encoder_keys
-            ):
-                # text encoder
-                text_encoder_lora_layers_to_save = state_dict
-            elif state_dict.keys() == unet_keys:
-                # unet
-                unet_lora_layers_to_save = state_dict
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
 
             # make sure to pop weight so that corresponding model is not saved again
             weights.pop()
@@ -915,27 +909,24 @@ def save_model_hook(models, weights, output_dir):
         )
 
     def load_model_hook(models, input_dir):
-        # Note we DON'T pass the unet and text encoder here an purpose
-        # so that the we don't accidentally override the LoRA layers of
-        # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
-        # with new torch.nn.Modules / weights. We simply use the pipeline class as
-        # an easy way to load the lora checkpoints
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            revision=args.revision,
-            torch_dtype=weight_dtype,
-        )
-        temp_pipeline.load_lora_weights(input_dir)
+        unet_ = None
+        text_encoder_ = None
 
-        # load lora weights into models
-        models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
-        if len(models) > 1:
-            models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
+        while len(models) > 0:
+            model = models.pop()
 
-        # delete temporary pipeline and pop models
-        del temp_pipeline
-        for _ in range(len(models)):
-            models.pop()
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_ = model
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+        LoraLoaderMixin.load_lora_into_text_encoder(
+            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
+        )
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -965,9 +956,9 @@ def load_model_hook(models, input_dir):
 
     # Optimizer creation
     params_to_optimize = (
-        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+        itertools.chain(unet_lora_parameters, text_lora_parameters)
         if args.train_text_encoder
-        else unet_lora_layers.parameters()
+        else unet_lora_parameters
     )
     optimizer = optimizer_class(
         params_to_optimize,
@@ -1056,12 +1047,12 @@ def compute_text_embeddings(prompt):
 
     # Prepare everything with our `accelerator`.
     if args.train_text_encoder:
-        unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
         )
     else:
-        unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
         )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -1210,9 +1201,9 @@ def compute_text_embeddings(prompt):
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     params_to_clip = (
-                        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+                        itertools.chain(unet_lora_parameters, text_lora_parameters)
                         if args.train_text_encoder
-                        else unet_lora_layers.parameters()
+                        else unet_lora_parameters
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
@@ -1301,15 +1292,17 @@ def compute_text_embeddings(prompt):
                 pipeline_args = {"prompt": args.validation_prompt}
 
                 if args.validation_images is None:
-                    images = [
-                        pipeline(**pipeline_args, generator=generator).images[0]
-                        for _ in range(args.num_validation_images)
-                    ]
+                    images = []
+                    for _ in range(args.num_validation_images):
+                        with torch.cuda.amp.autocast():
+                            image = pipeline(**pipeline_args, generator=generator).images[0]
+                        images.append(image)
                 else:
                     images = []
                     for image in args.validation_images:
                         image = Image.open(image)
-                        image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+                        with torch.cuda.amp.autocast():
+                            image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
                         images.append(image)
 
                 for tracker in accelerator.trackers:
@@ -1332,12 +1325,16 @@ def compute_text_embeddings(prompt):
     # Save the lora layers
    accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
+        unet_lora_layers = unet_attn_processors_state_dict(unet)
 
-        if text_encoder is not None:
+        if text_encoder is not None and args.train_text_encoder:
+            text_encoder = accelerator.unwrap_model(text_encoder)
             text_encoder = text_encoder.to(torch.float32)
-            text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
+            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
+        else:
+            text_encoder_lora_layers = None
 
         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,

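As a usage note on the new hooks: load_model_hook above no longer round-trips through a temporary pipeline; it parses the checkpoint once and routes the weights into each model. Below is a minimal stand-alone sketch of that load path, assuming a diffusers version with this commit, with "runwayml/stable-diffusion-v1-5" as an example base model and "lora-out" as a placeholder for a directory previously written by LoraLoaderMixin.save_lora_weights.

from diffusers import DiffusionPipeline
from diffusers.loaders import LoraLoaderMixin

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # example checkpoint

# Parse the serialized LoRA checkpoint once...
lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict("lora-out")  # placeholder directory

# ...then route the relevant weights into each model, as the hook does.
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=pipe.unet)
LoraLoaderMixin.load_lora_into_text_encoder(
    lora_state_dict, network_alpha=network_alpha, text_encoder=pipe.text_encoder
)

On the save side, the script now hands the plain state dicts produced by unet_attn_processors_state_dict and text_encoder_lora_state_dict to LoraLoaderMixin.save_lora_weights, which is why the AttnProcsLayers import could be dropped.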