62
62
63
63
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
64
64
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
65
+ TOTAL_EXAMPLE_KEYS = 5
65
66
66
67
TEXT_INVERSION_NAME = "learned_embeds.bin"
67
68
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -187,6 +188,7 @@ def map_from(module, state_dict, *args, **kwargs):
187
188
class UNet2DConditionLoadersMixin :
188
189
text_encoder_name = TEXT_ENCODER_NAME
189
190
unet_name = UNET_NAME
191
+ aux_state_dict_populated = None
190
192
191
193
def load_attn_procs (self , pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]], ** kwargs ):
192
194
r"""
@@ -1062,6 +1064,7 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet, state_dict_aux=Non
1062
1064
1063
1065
if state_dict_aux :
1064
1066
unet ._load_lora_aux (state_dict_aux , network_alpha = network_alpha )
1067
+ unet .aux_state_dict_populated = True
1065
1068
1066
1069
@classmethod
1067
1070
def load_lora_into_text_encoder (cls , state_dict , network_alpha , text_encoder , lora_scale = 1.0 , state_dict_aux = None ):
@@ -1314,9 +1317,12 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
1314
1317
unet_state_dict_aux = {}
1315
1318
te_state_dict_aux = {}
1316
1319
network_alpha = None
1320
+ unloaded_keys = []
1317
1321
1318
1322
for key , value in state_dict .items ():
1319
- if "lora_down" in key :
1323
+ if "hada" in key or "skip" in key :
1324
+ unloaded_keys .append (key )
1325
+ elif "lora_down" in key :
1320
1326
lora_name = key .split ("." )[0 ]
1321
1327
lora_name_up = lora_name + ".lora_up.weight"
1322
1328
lora_name_alpha = lora_name + ".alpha"
@@ -1351,6 +1357,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
1351
1357
elif any (key in diffusers_name for key in ("proj_in" , "proj_out" )):
1352
1358
unet_state_dict_aux [diffusers_name ] = value
1353
1359
unet_state_dict_aux [diffusers_name .replace (".down." , ".up." )] = state_dict [lora_name_up ]
1360
+
1354
1361
elif lora_name .startswith ("lora_te_" ):
1355
1362
diffusers_name = key .replace ("lora_te_" , "" ).replace ("_" , "." )
1356
1363
diffusers_name = diffusers_name .replace ("text.model" , "text_model" )
@@ -1366,6 +1373,13 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
1366
1373
te_state_dict_aux [diffusers_name ] = value
1367
1374
te_state_dict_aux [diffusers_name .replace (".down." , ".up." )] = state_dict [lora_name_up ]
1368
1375
1376
+ logger .info ("Kohya-style checkpoint detected." )
1377
+ if len (unloaded_keys ) > 0 :
1378
+ example_unloaded_keys = ", " .join (x for x in unloaded_keys [:TOTAL_EXAMPLE_KEYS ])
1379
+ logger .warning (
1380
+ f"There are some keys (such as: { example_unloaded_keys } ) in the checkpoints we don't provide support for."
1381
+ )
1382
+
1369
1383
unet_state_dict = {f"{ UNET_NAME } .{ module_name } " : params for module_name , params in unet_state_dict .items ()}
1370
1384
te_state_dict = {f"{ TEXT_ENCODER_NAME } .{ module_name } " : params for module_name , params in te_state_dict .items ()}
1371
1385
new_state_dict = {** unet_state_dict , ** te_state_dict }
@@ -1400,6 +1414,12 @@ def unload_lora_weights(self):
1400
1414
else :
1401
1415
self .unet .set_default_attn_processor ()
1402
1416
1417
+ if self .unet .aux_state_dict_populated :
1418
+ for _ , module in self .unet .named_modules ():
1419
+ if hasattr (module , "old_forward" ) and module .old_forward is not None :
1420
+ module .forward = module .old_forward
1421
+ self .unet .aux_state_dict_populated = False
1422
+
1403
1423
# Safe to call the following regardless of LoRA.
1404
1424
self ._remove_text_encoder_monkey_patch ()
1405
1425
0 commit comments