from huggingface_hub import hf_hub_download
from torch import nn

+ from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
from .utils import (
    DIFFUSERS_CACHE,
    HF_HUB_OFFLINE,
@@ -56,6 +57,7 @@

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+ TOTAL_EXAMPLE_KEYS = 5

TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -105,6 +107,20 @@ def text_encoder_attn_modules(text_encoder):
    return attn_modules


+ def text_encoder_mlp_modules(text_encoder):
+     mlp_modules = []
+
+     if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+         for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+             mlp_mod = layer.mlp
+             name = f"text_model.encoder.layers.{i}.mlp"
+             mlp_modules.append((name, mlp_mod))
+     else:
+         raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
+
+     return mlp_modules
+
+
def text_encoder_lora_state_dict(text_encoder):
    state_dict = {}

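A quick sketch of how the new text_encoder_mlp_modules helper can be used, assuming it is importable from diffusers.loaders; the model id below is only an example.

# Illustrative only: walk the CLIP text encoder MLP blocks the helper exposes for LoRA patching.
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
for name, mlp_mod in text_encoder_mlp_modules(text_encoder):
    print(name, mlp_mod.__class__.__name__)  # e.g. "text_model.encoder.layers.0.mlp CLIPMLP"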
@@ -304,6 +320,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

        # fill attn processors
        attn_processors = {}
+         non_attn_lora_layers = []

        is_lora = all("lora" in k for k in state_dict.keys())
        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
@@ -327,13 +344,33 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                lora_grouped_dict[attn_processor_key][sub_key] = value

            for key, value_dict in lora_grouped_dict.items():
-                 rank = value_dict["to_k_lora.down.weight"].shape[0]
-                 hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
-
                attn_processor = self
                for sub_key in key.split("."):
                    attn_processor = getattr(attn_processor, sub_key)

+                 # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
+                 # or add_{k,v,q,out_proj}_proj_lora layers.
+                 if "lora.down.weight" in value_dict:
+                     rank = value_dict["lora.down.weight"].shape[0]
+                     hidden_size = value_dict["lora.up.weight"].shape[0]
+
+                     if isinstance(attn_processor, LoRACompatibleConv):
+                         lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
+                     elif isinstance(attn_processor, LoRACompatibleLinear):
+                         lora = LoRALinearLayer(
+                             attn_processor.in_features, attn_processor.out_features, rank, network_alpha
+                         )
+                     else:
+                         raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
+
+                     value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+                     lora.load_state_dict(value_dict)
+                     non_attn_lora_layers.append((attn_processor, lora))
+                     continue
+
+                 rank = value_dict["to_k_lora.down.weight"].shape[0]
+                 hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
+
                if isinstance(
                    attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
                ):
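For context, a minimal sketch of the LoRA computation that the down/up weight pairs implement, which is why the rank is read from down.weight.shape[0] and the hidden size from up.weight.shape[0]. This is a simplified stand-in under standard LoRA assumptions, not the actual LoRALinearLayer or LoRAConv2dLayer code.

import torch
from torch import nn

class MinimalLoRALinear(nn.Module):
    # Simplified stand-in for a LoRA linear layer (illustrative, not the diffusers class).
    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)  # down.weight: [rank, in_features]
        self.up = nn.Linear(rank, out_features, bias=False)   # up.weight:   [out_features, rank]
        self.network_alpha = network_alpha
        self.rank = rank

    def forward(self, hidden_states):
        out = self.up(self.down(hidden_states))
        if self.network_alpha is not None:
            out = out * (self.network_alpha / self.rank)  # Kohya-style alpha scaling
        return out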
@@ -390,10 +427,16 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

        # set correct dtype & device
        attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
+         non_attn_lora_layers = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in non_attn_lora_layers]

        # set layers
        self.set_attn_processor(attn_processors)

+         # set ff layers
+         for target_module, lora_layer in non_attn_lora_layers:
+             if hasattr(target_module, "set_lora_layer"):
+                 target_module.set_lora_layer(lora_layer)
+
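A rough sketch of the contract set_lora_layer implies for the LoRA-compatible layers: the base layer keeps its own weights and, when a LoRA layer is attached, adds that branch's output on top (passing None detaches it again). This is a hypothetical illustration, not the diffusers implementation.

from torch import nn

class LoRACompatibleLinearSketch(nn.Linear):
    # Illustrative only: a Linear that can carry an optional LoRA branch.
    def __init__(self, *args, lora_layer=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer):
        self.lora_layer = lora_layer  # None detaches the LoRA branch

    def forward(self, hidden_states):
        out = super().forward(hidden_states)
        if self.lora_layer is not None:
            out = out + self.lora_layer(hidden_states)
        return out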
    def save_attn_procs(
        self,
        save_directory: Union[str, os.PathLike],
@@ -840,7 +883,10 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
        self.load_lora_into_text_encoder(
-             state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale
+             state_dict,
+             network_alpha=network_alpha,
+             text_encoder=self.text_encoder,
+             lora_scale=self.lora_scale,
        )

    @classmethod
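The public entry point is unchanged; a usage sketch, where the repository path, weight file name, and base model are placeholders.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Loads UNet and text-encoder LoRA layers in one call; the path below is a placeholder.
pipe.load_lora_weights("path/to/lora", weight_name="pytorch_lora_weights.safetensors")
image = pipe("a photo of a corgi in a spacesuit").images[0]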
@@ -1049,6 +1095,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr
            text_encoder_lora_state_dict = {
                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
            }
+
            if len(text_encoder_lora_state_dict) > 0:
                logger.info(f"Loading {prefix}.")

@@ -1092,8 +1139,9 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr
                rank = text_encoder_lora_state_dict[
                    "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
                ].shape[1]
+                 patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())

-                 cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
+                 cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)

                # set correct dtype & device
                text_encoder_lora_state_dict = {
@@ -1125,8 +1173,21 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
                attn_module.v_proj = attn_module.v_proj.regular_linear_layer
                attn_module.out_proj = attn_module.out_proj.regular_linear_layer

+         for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+             if isinstance(mlp_module.fc1, PatchedLoraProjection):
+                 mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
+                 mlp_module.fc2 = mlp_module.fc2.regular_linear_layer
+
    @classmethod
-     def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+     def _modify_text_encoder(
+         cls,
+         text_encoder,
+         lora_scale=1,
+         network_alpha=None,
+         rank=4,
+         dtype=None,
+         patch_mlp=False,
+     ):
        r"""
        Monkey-patches the forward passes of attention modules of the text encoder.
        """
@@ -1157,6 +1218,18 @@ def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, ra
            )
            lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

+         if patch_mlp:
+             for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+                 mlp_module.fc1 = PatchedLoraProjection(
+                     mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
+                 )
+                 lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
+
+                 mlp_module.fc2 = PatchedLoraProjection(
+                     mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
+                 )
+                 lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
+
        return lora_parameters
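For intuition, a simplified stand-in for what the monkey patch does to fc1/fc2 (and to the attention projections): each patched projection keeps the original Linear as regular_linear_layer and adds a scaled LoRA branch, which is also why the un-patch code above can simply restore regular_linear_layer. Names here are illustrative, not the actual PatchedLoraProjection class.

from torch import nn

class PatchedProjectionSketch(nn.Module):
    # Illustrative only: wraps an existing nn.Linear and adds a scaled LoRA branch.
    def __init__(self, regular_linear_layer, lora_scale=1.0, rank=4):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer
        self.lora_scale = lora_scale
        self.lora_linear_layer = nn.Sequential(
            nn.Linear(regular_linear_layer.in_features, rank, bias=False),
            nn.Linear(rank, regular_linear_layer.out_features, bias=False),
        )

    def forward(self, hidden_states):
        return self.regular_linear_layer(hidden_states) + self.lora_scale * self.lora_linear_layer(hidden_states)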

    @classmethod
@@ -1261,9 +1334,12 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
        unet_state_dict = {}
        te_state_dict = {}
        network_alpha = None
+         unloaded_keys = []

        for key, value in state_dict.items():
-             if "lora_down" in key:
+             if "hada" in key or "skip" in key:
+                 unloaded_keys.append(key)
+             elif "lora_down" in key:
                lora_name = key.split(".")[0]
                lora_name_up = lora_name + ".lora_up.weight"
                lora_name_alpha = lora_name + ".alpha"
@@ -1284,12 +1360,21 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
                    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
                    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+                     diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+                     diffusers_name = diffusers_name.replace("proj.out", "proj_out")
                    if "transformer_blocks" in diffusers_name:
                        if "attn1" in diffusers_name or "attn2" in diffusers_name:
                            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
                            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
                            unet_state_dict[diffusers_name] = value
                            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                         elif "ff" in diffusers_name:
+                             unet_state_dict[diffusers_name] = value
+                             unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                     elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
+                         unet_state_dict[diffusers_name] = value
+                         unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
                elif lora_name.startswith("lora_te_"):
                    diffusers_name = key.replace("lora_te_", "").replace("_", ".")
                    diffusers_name = diffusers_name.replace("text.model", "text_model")
@@ -1301,6 +1386,19 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                    if "self_attn" in diffusers_name:
                        te_state_dict[diffusers_name] = value
                        te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                     elif "mlp" in diffusers_name:
+                         # Be aware that this is the new diffusers convention and the rest of the code might
+                         # not utilize it yet.
+                         diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+                         te_state_dict[diffusers_name] = value
+                         te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
+         logger.info("Kohya-style checkpoint detected.")
+         if len(unloaded_keys) > 0:
+             example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
+             logger.warning(
+                 f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
+             )

        unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
        te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
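To make the key handling concrete, an illustration of how a Kohya-style key pairs with its companion entries; the key below is a made-up example, not taken from a real checkpoint.

# Each Kohya LoRA module ships a down weight, an up weight, and optionally an alpha
# scalar under the same prefix before the first ".".
key = "lora_unet_up_blocks_1_attentions_0_proj_in.lora_down.weight"  # hypothetical key
lora_name = key.split(".")[0]                  # "lora_unet_up_blocks_1_attentions_0_proj_in"
lora_name_up = lora_name + ".lora_up.weight"   # companion up-projection weight
lora_name_alpha = lora_name + ".alpha"         # optional network alpha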
@@ -1346,6 +1444,10 @@ def unload_lora_weights(self):
                [attention_proc_class] = unet_attention_classes
                self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())

+         for _, module in self.unet.named_modules():
+             if hasattr(module, "set_lora_layer"):
+                 module.set_lora_layer(None)
+
        # Safe to call the following regardless of LoRA.
        self._remove_text_encoder_monkey_patch()
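With the added set_lora_layer(None) pass, unloading also detaches the non-attention LoRA layers; usage stays the same. A sketch, assuming a pipeline pipe and a placeholder LoRA path as in the earlier example.

pipe.load_lora_weights("path/to/lora")   # placeholder path
image_with_lora = pipe("a photo of a corgi in a spacesuit").images[0]

pipe.unload_lora_weights()               # restores the original UNet / text encoder behavior
image_without_lora = pipe("a photo of a corgi in a spacesuit").images[0]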