Skip to content

Commit c2239a4

Browse files
authored
Merge pull request #1 from isidentical/fixes
[QoL] Small fixes
2 parents 4048fb1 + c26a02d commit c2239a4

File tree

3 files changed

+27
-2
lines changed

3 files changed

+27
-2
lines changed

src/diffusers/loaders.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262

6363
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
6464
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
65+
TOTAL_EXAMPLE_KEYS = 5
6566

6667
TEXT_INVERSION_NAME = "learned_embeds.bin"
6768
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -187,6 +188,7 @@ def map_from(module, state_dict, *args, **kwargs):
187188
class UNet2DConditionLoadersMixin:
188189
text_encoder_name = TEXT_ENCODER_NAME
189190
unet_name = UNET_NAME
191+
aux_state_dict_populated = None
190192

191193
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
192194
r"""
@@ -1062,6 +1064,7 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet, state_dict_aux=None
10621064

10631065
if state_dict_aux:
10641066
unet._load_lora_aux(state_dict_aux, network_alpha=network_alpha)
1067+
unet.aux_state_dict_populated = True
10651068

10661069
@classmethod
10671070
def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0, state_dict_aux=None):
@@ -1314,9 +1317,12 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
13141317
unet_state_dict_aux = {}
13151318
te_state_dict_aux = {}
13161319
network_alpha = None
1320+
unloaded_keys = []
13171321

13181322
for key, value in state_dict.items():
1319-
if "lora_down" in key:
1323+
if "hada" in key or "skip" in key:
1324+
unloaded_keys.append(key)
1325+
elif "lora_down" in key:
13201326
lora_name = key.split(".")[0]
13211327
lora_name_up = lora_name + ".lora_up.weight"
13221328
lora_name_alpha = lora_name + ".alpha"
@@ -1351,6 +1357,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
13511357
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
13521358
unet_state_dict_aux[diffusers_name] = value
13531359
unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
1360+
13541361
elif lora_name.startswith("lora_te_"):
13551362
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
13561363
diffusers_name = diffusers_name.replace("text.model", "text_model")
@@ -1366,6 +1373,13 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
13661373
te_state_dict_aux[diffusers_name] = value
13671374
te_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
13681375

1376+
logger.info("Kohya-style checkpoint detected.")
1377+
if len(unloaded_keys) > 0:
1378+
example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
1379+
logger.warning(
1380+
f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
1381+
)
1382+
13691383
unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
13701384
te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
13711385
new_state_dict = {**unet_state_dict, **te_state_dict}
@@ -1400,6 +1414,12 @@ def unload_lora_weights(self):
14001414
else:
14011415
self.unet.set_default_attn_processor()
14021416

1417+
if self.unet.aux_state_dict_populated:
1418+
for _, module in self.unet.named_modules():
1419+
if hasattr(module, "old_forward") and module.old_forward is not None:
1420+
module.forward = module.old_forward
1421+
self.unet.aux_state_dict_populated = False
1422+
14031423
# Safe to call the following regardless of LoRA.
14041424
self._remove_text_encoder_monkey_patch()
14051425

src/diffusers/models/lora.py

+4
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,13 @@ class Conv2dWithLoRA(nn.Conv2d):
8787
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
8888
super().__init__(*args, **kwargs)
8989
self.lora_layer = lora_layer
90+
self.old_forward = None
9091

9192
def forward(self, x):
9293
if self.lora_layer is None:
9394
return super().forward(x)
9495
else:
96+
self.old_forward = super().forward
9597
return super().forward(x) + self.lora_layer(x)
9698

9799

@@ -103,9 +105,11 @@ class LinearWithLoRA(nn.Linear):
103105
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
104106
super().__init__(*args, **kwargs)
105107
self.lora_layer = lora_layer
108+
self.old_forward = None
106109

107110
def forward(self, x):
108111
if self.lora_layer is None:
109112
return super().forward(x)
110113
else:
114+
self.old_forward = super().forward
111115
return super().forward(x) + self.lora_layer(x)

tests/models/test_lora_layers.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ def test_a1111(self):
554554

555555
images = images[0, -3:, -3:, -1].flatten()
556556

557-
expected = np.array([0.3743, 0.3893, 0.3835, 0.3891, 0.3949, 0.3649, 0.3858, 0.3802, 0.3245])
557+
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
558558

559559
self.assertTrue(np.allclose(images, expected, atol=1e-4))
560560

@@ -594,6 +594,7 @@ def test_unload_lora(self):
594594
lora_filename = "Colored_Icons_by_vizsumit.safetensors"
595595

596596
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
597+
generator = torch.manual_seed(0)
597598
lora_images = pipe(
598599
prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
599600
).images

0 commit comments

Comments
 (0)