@@ -120,18 +120,18 @@ def text_encoder_attn_modules(text_encoder):
     return attn_modules


-def text_encoder_aux_modules(text_encoder):
-    aux_modules = []
+def text_encoder_mlp_modules(text_encoder):
+    mlp_modules = []

     if isinstance(text_encoder, CLIPTextModel):
         for i, layer in enumerate(text_encoder.text_model.encoder.layers):
             mlp_mod = layer.mlp
             name = f"text_model.encoder.layers.{i}.mlp"
-            aux_modules.append((name, mlp_mod))
+            mlp_modules.append((name, mlp_mod))
     else:
-        raise ValueError(f"do not know how to get aux modules for: {text_encoder.__class__.__name__}")
+        raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")

-    return aux_modules
+    return mlp_modules


 def text_encoder_lora_state_dict(text_encoder):
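For reference, a minimal usage sketch of the renamed helper, assuming a `CLIPTextModel` from `transformers` (the checkpoint id below is only an example):

```python
from transformers import CLIPTextModel

# Example checkpoint; any CLIP text encoder exposes text_model.encoder.layers the same way.
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")

# Yields (qualified_name, module) pairs, one per transformer block's MLP.
for name, mlp in text_encoder_mlp_modules(text_encoder):
    print(name, type(mlp).__name__)  # e.g. "text_model.encoder.layers.0.mlp CLIPMLP"
```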
@@ -322,6 +322,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

         # fill attn processors
         attn_processors = {}
+        ff_layers = []

         is_lora = all("lora" in k for k in state_dict.keys())
         is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
@@ -345,13 +346,32 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                 lora_grouped_dict[attn_processor_key][sub_key] = value

             for key, value_dict in lora_grouped_dict.items():
-                rank = value_dict["to_k_lora.down.weight"].shape[0]
-                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
-
                 attn_processor = self
                 for sub_key in key.split("."):
                     attn_processor = getattr(attn_processor, sub_key)

+                # Process FF layers
+                if "lora.down.weight" in value_dict:
+                    rank = value_dict["lora.down.weight"].shape[0]
+                    hidden_size = value_dict["lora.up.weight"].shape[0]
+
+                    if isinstance(attn_processor, LoRACompatibleConv):
+                        lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
+                    elif isinstance(attn_processor, LoRACompatibleLinear):
+                        lora = LoRALinearLayer(
+                            attn_processor.in_features, attn_processor.out_features, rank, network_alpha
+                        )
+                    else:
+                        raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
+
+                    value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+                    lora.load_state_dict(value_dict)
+                    ff_layers.append((attn_processor, lora))
+                    continue
+
+                rank = value_dict["to_k_lora.down.weight"].shape[0]
+                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
+
                 if isinstance(
                     attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
                 ):
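To make the new branch concrete, here is how the key grouping above splits a single state-dict entry; the key is illustrative, not taken from a real checkpoint. Entries whose sub-key starts with `lora.` (feed-forward and projection layers) take the new `ff_layers` path, while attention entries keep their `to_k_lora.*`-style sub-keys:

```python
# Pure-Python sketch of the grouping performed by load_attn_procs above.
key = "mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.lora.down.weight"  # illustrative

module_path = ".".join(key.split(".")[:-3])  # "...ff.net.0.proj" -> used to getattr down to the module
sub_key = ".".join(key.split(".")[-3:])      # "lora.down.weight" -> triggers the ff_layers branch
print(module_path, "|", sub_key)
```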
@@ -408,10 +428,16 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

         # set correct dtype & device
         attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
+        ff_layers = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in ff_layers]

         # set layers
         self.set_attn_processor(attn_processors)

+        # set ff layers
+        for target_module, lora_layer in ff_layers:
+            if hasattr(target_module, "set_lora_layer"):
+                target_module.set_lora_layer(lora_layer)
+
     def save_attn_procs(
         self,
         save_directory: Union[str, os.PathLike],
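The `hasattr(target_module, "set_lora_layer")` check relies on the target being a LoRA-compatible layer. A toy sketch of that pattern (illustrative only; the real `LoRACompatibleLinear`/`LoRACompatibleConv` classes live in `diffusers.models.lora`):

```python
import torch.nn as nn


class ToyLoRACompatibleLinear(nn.Linear):
    """Minimal sketch: a linear layer that can carry an optional LoRA adapter."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = None

    def set_lora_layer(self, lora_layer):
        # Passing None detaches the adapter again (see unload_lora_weights below).
        self.lora_layer = lora_layer

    def forward(self, x):
        out = super().forward(x)
        if self.lora_layer is not None:
            out = out + self.lora_layer(x)
        return out
```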
@@ -489,36 +515,6 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

-    def _load_lora_aux(self, state_dict, network_alpha=None):
-        lora_grouped_dict = defaultdict(dict)
-        for key, value in state_dict.items():
-            attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
-            lora_grouped_dict[attn_processor_key][sub_key] = value
-
-        for key, value_dict in lora_grouped_dict.items():
-            rank = value_dict["lora.down.weight"].shape[0]
-            hidden_size = value_dict["lora.up.weight"].shape[0]
-            target_modules = [module for name, module in self.named_modules() if name == key]
-            if len(target_modules) == 0:
-                logger.warning(f"Could not find module {key} in the model. Skipping.")
-                continue
-
-            target_module = target_modules[0]
-            value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
-
-            lora = None
-            if isinstance(target_module, LoRACompatibleConv):
-                lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
-            elif isinstance(target_module, LoRACompatibleLinear):
-                lora = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha)
-            else:
-                raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
-            lora.load_state_dict(value_dict)
-            lora.to(device=self.device, dtype=self.dtype)
-
-            # install lora
-            target_module.lora_layer = lora
-

 class TextualInversionLoaderMixin:
     r"""
@@ -880,18 +876,13 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             kwargs:
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
-        state_dict, network_alpha, (unet_state_dict_aux, te_state_dict_aux) = self.lora_state_dict(
-            pretrained_model_name_or_path_or_dict, **kwargs
-        )
-        self.load_lora_into_unet(
-            state_dict, network_alpha=network_alpha, unet=self.unet, state_dict_aux=unet_state_dict_aux
-        )
+        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
         self.load_lora_into_text_encoder(
             state_dict,
             network_alpha=network_alpha,
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
-            state_dict_aux=te_state_dict_aux,
         )

     @classmethod
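The user-facing call does not change with the auxiliary tuple gone; a usage sketch with placeholder paths:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Works for diffusers-format and kohya-ss LoRA files alike; the path and
# weight_name below are placeholders.
pipe.load_lora_weights("path/to/lora", weight_name="pytorch_lora_weights.safetensors")
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
```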
@@ -1025,14 +1016,13 @@ def lora_state_dict(

         # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
         network_alpha = None
-        auxilary_states = ({}, {})
         if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
-            state_dict, network_alpha, auxilary_states = cls._convert_kohya_lora_to_diffusers(state_dict)
+            state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict)

-        return state_dict, network_alpha, auxilary_states
+        return state_dict, network_alpha

     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alpha, unet, state_dict_aux=None):
+    def load_lora_into_unet(cls, state_dict, network_alpha, unet):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`

@@ -1045,8 +1035,6 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet, state_dict_aux=Non
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
-            state_dict_aux (`dict`, *optional*):
-                A dictionary containing the auxilary state (additional lora state) dict for the unet.
         """

         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1071,12 +1059,8 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet, state_dict_aux=Non
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
             warnings.warn(warn_message)

-        if state_dict_aux:
-            unet._load_lora_aux(state_dict_aux, network_alpha=network_alpha)
-            unet.aux_state_dict_populated = True
-
     @classmethod
-    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0, state_dict_aux=None):
+    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`

@@ -1091,8 +1075,6 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the regular
                 lora layer.
-            state_dict_aux (`dict`, *optional*):
-                A dictionary containing the auxilary state dict (additional lora state) for the text encoder.
         """

         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1105,8 +1087,6 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
         text_encoder_lora_state_dict = {
             k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
         }
-        if state_dict_aux:
-            text_encoder_lora_state_dict = {**text_encoder_lora_state_dict, **state_dict_aux}

         if len(text_encoder_lora_state_dict) > 0:
             logger.info(f"Loading {cls.text_encoder_name}.")
@@ -1148,8 +1128,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
                         f"{name}.out_proj.lora_linear_layer.down.weight"
                     ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")

-            if state_dict_aux:
-                for name, _ in text_encoder_aux_modules(text_encoder):
+                for name, _ in text_encoder_mlp_modules(text_encoder):
                     for direction in ["up", "down"]:
                         for layer in ["fc1", "fc2"]:
                             original_key = f"{name}.{layer}.lora.{direction}.weight"
@@ -1163,9 +1142,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
                 "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
             ].shape[1]

-            cls._modify_text_encoder(
-                text_encoder, lora_scale, network_alpha, rank=rank, patch_aux=bool(state_dict_aux)
-            )
+            cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)

             # set correct dtype & device
             text_encoder_lora_state_dict = {
@@ -1197,13 +1174,10 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
                 attn_module.v_proj = attn_module.v_proj.regular_linear_layer
                 attn_module.out_proj = attn_module.out_proj.regular_linear_layer

-        if getattr(text_encoder, "aux_state_dict_populated", False):
-            for _, aux_module in text_encoder_aux_modules(text_encoder):
-                if isinstance(aux_module.fc1, PatchedLoraProjection):
-                    aux_module.fc1 = aux_module.fc1.regular_linear_layer
-                    aux_module.fc2 = aux_module.fc2.regular_linear_layer
-
-            text_encoder.aux_state_dict_populated = False
+        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+            if isinstance(mlp_module.fc1, PatchedLoraProjection):
+                mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
+                mlp_module.fc2 = mlp_module.fc2.regular_linear_layer

     @classmethod
     def _modify_text_encoder(
@@ -1213,7 +1187,6 @@ def _modify_text_encoder(
         network_alpha=None,
         rank=4,
         dtype=None,
-        patch_aux=False,
     ):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
@@ -1245,19 +1218,12 @@ def _modify_text_encoder(
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

-        if patch_aux:
-            for _, aux_module in text_encoder_aux_modules(text_encoder):
-                aux_module.fc1 = PatchedLoraProjection(
-                    aux_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
-                )
-                lora_parameters.extend(aux_module.fc1.lora_linear_layer.parameters())
-
-                aux_module.fc2 = PatchedLoraProjection(
-                    aux_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
-                )
-                lora_parameters.extend(aux_module.fc2.lora_linear_layer.parameters())
+        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+            mlp_module.fc1 = PatchedLoraProjection(mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype)
+            lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

-            text_encoder.aux_state_dict_populated = True
+            mlp_module.fc2 = PatchedLoraProjection(mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype)
+            lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())

         return lora_parameters

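For intuition, a toy sketch of what the fc1/fc2 monkey patch above does (illustrative only; the real `PatchedLoraProjection` also handles scaling via `lora_scale`, `network_alpha`, dtype, and state-dict naming):

```python
import torch.nn as nn


class ToyLoRALinear(nn.Module):
    """Low-rank adapter: project down to `rank`, then back up; starts as a no-op."""

    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.up.weight)

    def forward(self, x):
        return self.up(self.down(x))


class ToyPatchedProjection(nn.Module):
    """Wraps fc1/fc2 so a scaled LoRA delta is added to the regular output."""

    def __init__(self, regular_linear_layer, lora_scale=1.0, rank=4):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer
        self.lora_scale = lora_scale
        self.lora_linear_layer = ToyLoRALinear(
            regular_linear_layer.in_features, regular_linear_layer.out_features, rank=rank
        )

    def forward(self, x):
        return self.regular_linear_layer(x) + self.lora_scale * self.lora_linear_layer(x)
```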
@@ -1343,8 +1309,6 @@ def save_function(weights, filename):
     def _convert_kohya_lora_to_diffusers(cls, state_dict):
         unet_state_dict = {}
         te_state_dict = {}
-        unet_state_dict_aux = {}
-        te_state_dict_aux = {}
         network_alpha = None
         unloaded_keys = []

@@ -1381,11 +1345,11 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                     unet_state_dict[diffusers_name] = value
                     unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                 elif "ff" in diffusers_name:
-                    unet_state_dict_aux[diffusers_name] = value
-                    unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                    unet_state_dict[diffusers_name] = value
+                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                 elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
-                    unet_state_dict_aux[diffusers_name] = value
-                    unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                    unet_state_dict[diffusers_name] = value
+                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

             elif lora_name.startswith("lora_te_"):
                 diffusers_name = key.replace("lora_te_", "").replace("_", ".")
@@ -1399,8 +1363,8 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                     te_state_dict[diffusers_name] = value
                     te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                 elif "mlp" in diffusers_name:
-                    te_state_dict_aux[diffusers_name] = value
-                    te_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                    te_state_dict[diffusers_name] = value
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

         logger.info("Kohya-style checkpoint detected.")
         if len(unloaded_keys) > 0:
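For context on the routing after this change: keys prefixed `lora_unet_` (including the `ff` and `proj_in`/`proj_out` entries) now all land in `unet_state_dict`, and `lora_te_` keys (including `mlp` entries) in `te_state_dict`, instead of the dropped aux dicts. A rough sketch of the prefix handling, leaving out the additional name fix-ups the converter applies elsewhere:

```python
# Illustrative only; real kohya keys also go through further renames in
# _convert_kohya_lora_to_diffusers before they match diffusers module paths.
def route(key):
    if key.startswith("lora_unet_"):
        return "unet_state_dict", key.replace("lora_unet_", "").replace("_", ".")
    elif key.startswith("lora_te_"):
        return "te_state_dict", key.replace("lora_te_", "").replace("_", ".")
    return "unloaded", key


print(route("lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"))
```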
@@ -1412,7 +1376,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
         unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
         te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
         new_state_dict = {**unet_state_dict, **te_state_dict}
-        return new_state_dict, network_alpha, (unet_state_dict_aux, te_state_dict_aux)
+        return new_state_dict, network_alpha

     def unload_lora_weights(self):
         """
@@ -1442,11 +1406,9 @@ def unload_lora_weights(self):
             [attention_proc_class] = unet_attention_classes
             self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())

-        if getattr(self.unet, "aux_state_dict_populated", None):
-            for _, module in self.unet.named_modules():
-                if hasattr(module, "lora_layer") and module.lora_layer is not None:
-                    module.lora_layer = None
-            self.unet.aux_state_dict_populated = False
+        for _, module in self.unet.named_modules():
+            if hasattr(module, "set_lora_layer"):
+                module.set_lora_layer(None)

         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
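A teardown sketch matching the new loop (continuing the placeholder pipeline from the earlier example):

```python
# Detach UNet LoRA adapters via set_lora_layer(None) and undo the text encoder
# monkey patch; modules without set_lora_layer are skipped by the hasattr check,
# so this is safe even if no LoRA was loaded.
pipe.unload_lora_weights()
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
```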