Commit dbea93c
[Advanced LoRA v1.5] fix: gradient unscaling problem (#7018)
Co-authored-by: Linoy Tsaban <[email protected]>
Parent: dd3e554
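Background on the bug: when training with --mixed_precision fp16 and resuming from an intermediate checkpoint, the reloaded LoRA parameters could end up in fp16, and torch's GradScaler refuses to unscale fp16 gradients at the next optimizer step. A minimal sketch of that failure mode (illustrative only, not from the commit; assumes a CUDA device):

import torch

# GradScaler raises when asked to unscale gradients stored in fp16 --
# the "gradient unscaling problem" named in the commit title.
scaler = torch.cuda.amp.GradScaler()
param = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16, device="cuda"))
optimizer = torch.optim.SGD([param], lr=1e-2)

scaler.scale(param.sum()).backward()
scaler.unscale_(optimizer)  # ValueError: Attempting to unscale FP16 gradients.

The hook changed below avoids this by loading the adapter weights directly into the PEFT modules and re-casting the trainable parameters to fp32.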

File tree: 1 file changed (+34 −2 lines)

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

@@ -39,7 +39,7 @@
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -59,12 +59,13 @@
 )
 from diffusers.loaders import StableDiffusionLoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_all_state_dict_to_peft,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
+    convert_unet_state_dict_to_peft,
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
@@ -1319,6 +1320,37 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")

+        lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
+
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_])
+            # only upcast trainable parameters (LoRA) into fp32
+            cast_training_params(models)
         lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
         StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
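The final added block is the actual unscaling fix: the freshly loaded LoRA weights inherit the base model's weight_dtype, so under fp16 they would hit the GradScaler error shown earlier. Roughly, cast_training_params upcasts only the trainable parameters; a simplified sketch of the helper (not the library code verbatim):

import torch

def cast_training_params_sketch(models, dtype=torch.float32):
    # Upcast only parameters that require gradients (the LoRA weights);
    # the frozen base weights stay in fp16 to keep memory usage low.
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(dtype)

Restricting the cast to trainable parameters is deliberate: upcasting the full base model would erase most of the memory savings of mixed-precision training.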

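For reference, the hook modified here is wired into Accelerate elsewhere in the script, so it runs whenever training resumes from a saved state. A minimal sketch of that wiring, assuming the Accelerator instance and the save/load hooks defined in the script (the checkpoint path is illustrative):

# Hooks registered once at startup; save_state()/load_state() then route
# model (de)serialization through save_model_hook/load_model_hook.
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

# Resuming (e.g. via --resume_from_checkpoint) triggers load_model_hook:
accelerator.load_state("checkpoint-500")  # illustrative path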