@@ -39,7 +39,7 @@
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -59,12 +59,13 @@
 )
 from diffusers.loaders import StableDiffusionLoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_all_state_dict_to_peft,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
+    convert_unet_state_dict_to_peft,
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
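
The newly imported helpers bridge two LoRA key formats. Serialized checkpoints namespace every tensor by component (`unet.`, `text_encoder.`) and use diffusers-style `lora.down`/`lora.up` names, while PEFT modules expect unprefixed `lora_A`/`lora_B` keys. Here is a minimal sketch of the reshaping the hook below performs, with hypothetical keys and shapes (illustrative only, not taken from a real checkpoint):

```python
import torch

# Hypothetical entries -- a real checkpoint holds many such keys.
lora_state_dict = {
    "unet.mid_block.attentions.0.proj_in.lora.down.weight": torch.zeros(4, 1280),
    "unet.mid_block.attentions.0.proj_in.lora.up.weight": torch.zeros(1280, 4),
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora.down.weight": torch.zeros(4, 768),
}

# Keep only the UNet tensors and drop the component prefix. After this step,
# convert_unet_state_dict_to_peft renames lora.down/lora.up to lora_A/lora_B
# so set_peft_model_state_dict can load them into the PEFT-wrapped UNet.
unet_state_dict = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}
assert "mid_block.attentions.0.proj_in.lora.down.weight" in unet_state_dict
```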
@@ -1319,6 +1320,33 @@ def load_model_hook(models, input_dir):
         else:
             raise ValueError(f"unexpected save model: {model.__class__}")
 
+        lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
+
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_])
+            # only upcast trainable parameters (LoRA) into fp32
+            cast_training_params(models)
         lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
         StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
 
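For context on when this hook fires: in these training scripts the save/load hooks are registered with Accelerate, so `load_model_hook` runs inside `accelerator.load_state()` when training resumes from a checkpoint. A rough usage sketch follows, assuming the `accelerator`, `args`, and hook functions defined elsewhere in the script; the checkpoint directory name is made up for illustration:

```python
import os

# Register the hooks so accelerate routes custom (de)serialization through them.
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

# On resume, load_model_hook receives the unwrapped models plus the checkpoint
# directory; the fp16 branch above then re-upcasts the LoRA parameters to fp32
# so the optimizer keeps full-precision trainable weights.
accelerator.load_state(os.path.join(args.output_dir, "checkpoint-500"))
```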