Refactor LoRA #3778

Merged · 4 commits · Jul 9, 2023
151 changes: 74 additions & 77 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -23,6 +23,7 @@
import shutil
import warnings
from pathlib import Path
from typing import Dict

import numpy as np
import torch
@@ -50,7 +51,10 @@
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.loaders import (
LoraLoaderMixin,
text_encoder_lora_state_dict,
)
from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
@@ -60,7 +64,7 @@
SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -653,6 +657,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
return prompt_embeds


def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
r"""
Returns:
a state dict containing just the attention processor parameters.
"""
attn_processors = unet.attn_processors

attn_processors_state_dict = {}

for attn_processor_key, attn_processor in attn_processors.items():
for parameter_key, parameter in attn_processor.state_dict().items():
attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter

return attn_processors_state_dict

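As an aside (not part of the PR), here is a toy sketch of the flattening this helper performs: each parameter is re-keyed as "<processor name>.<parameter name>", which is the flat layout later handed to LoraLoaderMixin.save_lora_weights. The nn.Linear stand-ins below are illustrative only, not real attention processors.

import torch.nn as nn

# Stand-in for unet.attn_processors: a mapping from processor name to module.
toy_processors = {
    "down_blocks.0.attn1.processor": nn.Linear(4, 4, bias=False),
    "mid_block.attn1.processor": nn.Linear(4, 4, bias=False),
}

flat = {}
for proc_name, proc in toy_processors.items():
    for param_name, param in proc.state_dict().items():
        flat[f"{proc_name}.{param_name}"] = param

print(sorted(flat))
# ['down_blocks.0.attn1.processor.weight', 'mid_block.attn1.processor.weight']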

def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)

@@ -833,6 +853,7 @@ def main(args):

# Set correct lora layers
unet_lora_attn_procs = {}
unet_lora_parameters = []
for name, attn_processor in unet.attn_processors.items():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
@@ -850,35 +871,18 @@
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
unet_lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=args.rank,
)

module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
unet_lora_attn_procs[name] = module
unet_lora_parameters.extend(module.parameters())

unet.set_attn_processor(unet_lora_attn_procs)
unet_lora_layers = AttnProcsLayers(unet.attn_processors)

# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
# we first load a dummy pipeline with the text encoder and then do the monkey-patching.
text_encoder_lora_layers = None
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
text_lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=module.out_proj.out_features,
cross_attention_dim=None,
rank=args.rank,
)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
temp_pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, text_encoder=text_encoder
)
temp_pipeline._modify_text_encoder(text_lora_attn_procs)
text_encoder = temp_pipeline.text_encoder
del temp_pipeline
# ensure that dtype is float32, even if the rest of the model (which isn't trained) is loaded in fp16
text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)
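For intuition, here is a conceptual sketch of what this monkey-patching amounts to; the LoRADelta class and patch_linear_with_lora helper are illustrative inventions, not the actual LoraLoaderMixin._modify_text_encoder internals. Each frozen projection keeps its weights, a small trainable low-rank update is added to its output in fp32, and the new parameters are collected for the optimizer.

import torch
import torch.nn as nn

class LoRADelta(nn.Module):
    """Trainable low-rank update added on top of a frozen linear projection."""

    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)  # the patch starts as a no-op

    def forward(self, x):
        return self.up(self.down(x))

def patch_linear_with_lora(linear, rank=4, dtype=torch.float32):
    """Replace linear.forward so it adds the LoRA update; return the new params."""
    linear.requires_grad_(False)  # the original projection stays frozen
    delta = LoRADelta(linear.in_features, linear.out_features, rank).to(dtype)
    original_forward = linear.forward

    def lora_forward(x):
        # run the frozen projection, then add the low-rank update computed in fp32
        return original_forward(x) + delta(x.to(dtype)).to(x.dtype)

    linear.forward = lora_forward
    return list(delta.parameters())

# usage: patch a toy projection and collect its trainable parameters
proj = nn.Linear(768, 768)
text_lora_params = patch_linear_with_lora(proj, rank=4)
out = proj(torch.randn(2, 77, 768))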

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
@@ -887,23 +891,13 @@ def save_model_hook(models, weights, output_dir):
unet_lora_layers_to_save = None
text_encoder_lora_layers_to_save = None

if args.train_text_encoder:
text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()

for model in models:
state_dict = model.state_dict()

if (
text_encoder_lora_layers is not None
and text_encoder_keys is not None
and state_dict.keys() == text_encoder_keys
):
# text encoder
text_encoder_lora_layers_to_save = state_dict
elif state_dict.keys() == unet_keys:
# unet
unet_lora_layers_to_save = state_dict
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")

# make sure to pop weight so that corresponding model is not saved again
weights.pop()
@@ -915,27 +909,24 @@ def save_model_hook(models, weights, output_dir):
)

def load_model_hook(models, input_dir):
# Note we DON'T pass the unet and text encoder here on purpose
# so that we don't accidentally override the LoRA layers of
# unet_lora_layers and text_encoder_lora_layers which are stored in `models`
# with new torch.nn.Modules / weights. We simply use the pipeline class as
# an easy way to load the lora checkpoints
temp_pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
torch_dtype=weight_dtype,
)
temp_pipeline.load_lora_weights(input_dir)
unet_ = None
text_encoder_ = None

# load lora weights into models
models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
if len(models) > 1:
models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
while len(models) > 0:
model = models.pop()

# delete temporary pipeline and pop models
del temp_pipeline
for _ in range(len(models)):
models.pop()
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_ = model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
)

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
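As a minimal, self-contained sketch (toy model and placeholder checkpoint directory, not from the PR) of when these hooks fire: accelerator.save_state(...) calls every registered save pre-hook with the prepared models, their state dicts, and the output directory before the default serialization runs, which is why the hooks above pop the entries they have already handled.

import torch
from accelerate import Accelerator

accelerator = Accelerator()

def save_hook(models, weights, output_dir):
    # custom serialization would go here; popping a weights entry tells
    # accelerate not to save that model again in its default format
    print(f"saving {len(models)} model(s) to {output_dir}")
    while weights:
        weights.pop()

accelerator.register_save_state_pre_hook(save_hook)

model = accelerator.prepare(torch.nn.Linear(4, 4))
accelerator.save_state("tmp-checkpoint")  # triggers save_hook first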
@@ -965,9 +956,9 @@ def load_model_hook(models, input_dir):

# Optimizer creation
params_to_optimize = (
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
itertools.chain(unet_lora_parameters, text_lora_parameters)
if args.train_text_encoder
else unet_lora_layers.parameters()
else unet_lora_parameters
)
optimizer = optimizer_class(
params_to_optimize,
@@ -1056,12 +1047,12 @@ def compute_text_embeddings(prompt):

# Prepare everything with our `accelerator`.
if args.train_text_encoder:
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_layers, optimizer, train_dataloader, lr_scheduler
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -1210,9 +1201,9 @@ def compute_text_embeddings(prompt):
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
itertools.chain(unet_lora_parameters, text_lora_parameters)
if args.train_text_encoder
else unet_lora_layers.parameters()
else unet_lora_parameters
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -1301,15 +1292,17 @@ def compute_text_embeddings(prompt):
pipeline_args = {"prompt": args.validation_prompt}

if args.validation_images is None:
images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
]
images = []
for _ in range(args.num_validation_images):
with torch.cuda.amp.autocast():
Contributor: Why these changes? We try to avoid running pipelines in autocast if possible.

Contributor Author: I think I copy and pasted from the regular dreambooth training script, which runs the validation inference under autocast. Will remove.

Contributor Author: Actually, I might have added this because when we load the rest of the model in fp16, we keep the lora weights in fp32 and needed autocast for them to work together.

Contributor Author: Yes, looking through the commit history, that's why I added it. Is that ok?

Member: For running inference validation at intervals, I think keeping autocast is okay, as it helps keep things simple.

Contributor: Why didn't we need it before?

Contributor Author: Ok, so I was wrong that we needed it because of the difference in dtype of the lora weights; the lora layer casts manually internally.

The issue was the dtype of the output of the unet being passed to the vae.

The difference is that in the main version of the script, the unet is not wrapped in amp, so the output of the unet during validation is the same dtype as the unet, fp16.

In the branch, the full unet is wrapped in amp, so even though the unet is loaded in fp16, the output is fp32, and then there's an error when the fp32 latents are passed to the fp16 vae.

Contributor Author (@williamberman, Jul 6, 2023): The alternative to wrapping this section in amp is to put a check, either in the pipeline or at the beginning of the vae, for the dtype of the initial latents and manually cast them if necessary. I prefer to avoid manual casts like that, in case the caller expects execution in the dtype of their input (which is a reasonable assumption imo). So I would prefer to leave the amp decorator in this case.
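To make the dtype point concrete, here is a tiny illustrative reproduction (assumes a CUDA device; the Conv2d below is only a stand-in for the fp16 vae, not the actual module): an fp16 module rejects fp32 inputs unless the call runs under autocast, which downcasts them.

import torch

vae_like = torch.nn.Conv2d(4, 4, 3, padding=1).cuda().half()  # stand-in for the fp16 vae
latents = torch.randn(1, 4, 64, 64, device="cuda")            # fp32, like the amp-wrapped unet output

try:
    vae_like(latents)
except RuntimeError as err:
    print("without autocast:", err)  # dtype mismatch between input and weight

with torch.cuda.amp.autocast():
    out = vae_like(latents)          # conv inputs are autocast down to fp16
print("with autocast:", out.dtype)   # torch.float16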

image = pipeline(**pipeline_args, generator=generator).images[0]
images.append(image)
else:
images = []
for image in args.validation_images:
image = Image.open(image)
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
with torch.cuda.amp.autocast():
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image)

for tracker in accelerator.trackers:
@@ -1332,12 +1325,16 @@ def compute_text_embeddings(prompt):
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
unet_lora_layers = unet_attn_processors_state_dict(unet)

if text_encoder is not None:
if text_encoder is not None and args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder = text_encoder.to(torch.float32)
text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
else:
text_encoder_lora_layers = None

LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,