import shutil
import warnings
from pathlib import Path
+from typing import Dict

import numpy as np
import torch
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
+from diffusers.loaders import (
+    LoraLoaderMixin,
+    text_encoder_lora_state_dict,
+)
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
-from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

@@ -653,6 +657,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
    return prompt_embeds


+def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
+    r"""
+    Returns:
+        a state dict containing just the attention processor parameters.
+    """
+    attn_processors = unet.attn_processors
+
+    attn_processors_state_dict = {}
+
+    for attn_processor_key, attn_processor in attn_processors.items():
+        for parameter_key, parameter in attn_processor.state_dict().items():
+            attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
+
+    return attn_processors_state_dict
+
+
def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

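For reference, a minimal usage sketch of the `unet_attn_processors_state_dict` helper added above (illustrative only, not part of the change): it assumes a `unet` whose LoRA attention processors have already been set, as `main` does further down, and a made-up output directory. The flattened `"<processor name>.<parameter name>"` keys are the layout `LoraLoaderMixin.save_lora_weights` accepts for `unet_lora_layers`:

unet_lora_state = unet_attn_processors_state_dict(unet)
# keys look like, e.g.:
# "mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
LoraLoaderMixin.save_lora_weights(save_directory="lora-checkpoint", unet_lora_layers=unet_lora_state)  # "lora-checkpoint" is a placeholder
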
@@ -833,6 +853,7 @@ def main(args):

    # Set correct lora layers
    unet_lora_attn_procs = {}
+    unet_lora_parameters = []
    for name, attn_processor in unet.attn_processors.items():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
@@ -850,35 +871,18 @@ def main(args):
        lora_attn_processor_class = (
            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
        )
-        unet_lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.rank,
-        )
+
+        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        unet_lora_attn_procs[name] = module
+        unet_lora_parameters.extend(module.parameters())

    unet.set_attn_processor(unet_lora_attn_procs)
-    unet_lora_layers = AttnProcsLayers(unet.attn_processors)

    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
-    # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
-    # we first load a dummy pipeline with the text encoder and then do the monkey-patching.
-    text_encoder_lora_layers = None
+    # So, instead, we monkey-patch the forward calls of its attention-blocks.
    if args.train_text_encoder:
-        text_lora_attn_procs = {}
-        for name, module in text_encoder.named_modules():
-            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                text_lora_attn_procs[name] = LoRAAttnProcessor(
-                    hidden_size=module.out_proj.out_features,
-                    cross_attention_dim=None,
-                    rank=args.rank,
-                )
-        text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, text_encoder=text_encoder
-        )
-        temp_pipeline._modify_text_encoder(text_lora_attn_procs)
-        text_encoder = temp_pipeline.text_encoder
-        del temp_pipeline
+        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)

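A compact sketch of how the two text-encoder helpers pair up in this script (illustrative only; the `assert` just spells out the expectation stated in the comment above): `LoraLoaderMixin._modify_text_encoder` injects the LoRA layers and hands back their parameters for the optimizer, while `text_encoder_lora_state_dict` later extracts exactly those injected weights for saving.

text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)
# the injected parameters are trainable and kept in float32
assert all(p.requires_grad and p.dtype == torch.float32 for p in text_lora_parameters)
# only the injected LoRA weights, in the format save_lora_weights(...) expects
text_encoder_lora_weights = text_encoder_lora_state_dict(text_encoder)
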
    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
@@ -887,23 +891,13 @@ def save_model_hook(models, weights, output_dir):
        unet_lora_layers_to_save = None
        text_encoder_lora_layers_to_save = None

-        if args.train_text_encoder:
-            text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
-            unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
-
        for model in models:
-            state_dict = model.state_dict()
-
-            if (
-                text_encoder_lora_layers is not None
-                and text_encoder_keys is not None
-                and state_dict.keys() == text_encoder_keys
-            ):
-                # text encoder
-                text_encoder_lora_layers_to_save = state_dict
-            elif state_dict.keys() == unet_keys:
-                # unet
-                unet_lora_layers_to_save = state_dict
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")

            # make sure to pop weight so that corresponding model is not saved again
            weights.pop()
@@ -915,27 +909,24 @@ def save_model_hook(models, weights, output_dir):
        )

    def load_model_hook(models, input_dir):
-        # Note we DON'T pass the unet and text encoder here an purpose
-        # so that the we don't accidentally override the LoRA layers of
-        # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
-        # with new torch.nn.Modules / weights. We simply use the pipeline class as
-        # an easy way to load the lora checkpoints
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            revision=args.revision,
-            torch_dtype=weight_dtype,
-        )
-        temp_pipeline.load_lora_weights(input_dir)
+        unet_ = None
+        text_encoder_ = None

-        # load lora weights into models
-        models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
-        if len(models) > 1:
-            models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
+        while len(models) > 0:
+            model = models.pop()

-        # delete temporary pipeline and pop models
-        del temp_pipeline
-        for _ in range(len(models)):
-            models.pop()
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_ = model
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+        LoraLoaderMixin.load_lora_into_text_encoder(
+            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
+        )

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
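With both hooks registered, intermediate checkpoints are routed through them. A rough sketch of the flow (the `checkpoint-<step>` path, `os`, and `global_step` are the usual names from the rest of the script and are assumptions here):

save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)  # save_model_hook writes only the LoRA layers via save_lora_weights
# when resuming (e.g. --resume_from_checkpoint), the counterpart restores them:
accelerator.load_state(save_path)  # load_model_hook loads them back into unet / text_encoder
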
@@ -965,9 +956,9 @@ def load_model_hook(models, input_dir):

    # Optimizer creation
    params_to_optimize = (
-        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+        itertools.chain(unet_lora_parameters, text_lora_parameters)
        if args.train_text_encoder
-        else unet_lora_layers.parameters()
+        else unet_lora_parameters
    )
    optimizer = optimizer_class(
        params_to_optimize,
@@ -1056,12 +1047,12 @@ def compute_text_embeddings(prompt):

    # Prepare everything with our `accelerator`.
    if args.train_text_encoder:
-        unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
-        unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
        )

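Note the shift here: the full `unet` (and, when `--train_text_encoder` is set, the `text_encoder`) is now what gets wrapped by `accelerator.prepare`, rather than the removed `AttnProcsLayers` wrappers, while the optimizer above still receives only the collected LoRA parameter lists, so the frozen base weights are never updated.
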
    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -1210,9 +1201,9 @@ def compute_text_embeddings(prompt):
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
-                        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+                        itertools.chain(unet_lora_parameters, text_lora_parameters)
                        if args.train_text_encoder
-                        else unet_lora_layers.parameters()
+                        else unet_lora_parameters
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
@@ -1301,15 +1292,17 @@ def compute_text_embeddings(prompt):
                pipeline_args = {"prompt": args.validation_prompt}

                if args.validation_images is None:
-                    images = [
-                        pipeline(**pipeline_args, generator=generator).images[0]
-                        for _ in range(args.num_validation_images)
-                    ]
+                    images = []
+                    for _ in range(args.num_validation_images):
+                        with torch.cuda.amp.autocast():
+                            image = pipeline(**pipeline_args, generator=generator).images[0]
+                            images.append(image)
                else:
                    images = []
                    for image in args.validation_images:
                        image = Image.open(image)
-                        image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+                        with torch.cuda.amp.autocast():
+                            image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
                        images.append(image)

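The `torch.cuda.amp.autocast()` context around the validation calls is presumably needed because the trainable LoRA weights are kept in float32 (see the `dtype=torch.float32` comment earlier) while the rest of the pipeline may be loaded in fp16 under mixed precision, so inference has to bridge the two dtypes.
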
                for tracker in accelerator.trackers:
@@ -1332,12 +1325,16 @@ def compute_text_embeddings(prompt):
    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
        unet = unet.to(torch.float32)
-        unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
+        unet_lora_layers = unet_attn_processors_state_dict(unet)

-        if text_encoder is not None:
+        if text_encoder is not None and args.train_text_encoder:
+            text_encoder = accelerator.unwrap_model(text_encoder)
            text_encoder = text_encoder.to(torch.float32)
-            text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
+            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
+        else:
+            text_encoder_lora_layers = None

        LoraLoaderMixin.save_lora_weights(
            save_directory=args.output_dir,