
Commit f712f56

sayakpaul and patrickvonplaten authored and Jimmy committed
[LoRA] Fix SDXL text encoder LoRAs (huggingface#4371)
* temporarily disable text encoder loras. * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debbuging. * modify doc. * rename tests. * print slices. * fix: assertions * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent d3feaee commit f712f56

3 files changed, +37 -29 lines changed


docs/source/en/training/lora.md

Lines changed: 0 additions & 1 deletion
@@ -401,5 +401,4 @@ Thanks to [@isidentical](https://github.com/isidentical) for helping us on integ
 
 ### Known limitations specific to the Kohya-styled LoRAs
 
-* SDXL LoRAs that have both the text encoders are currently leading to weird results. We're actively investigating the issue.
 * When images don't looks similar to other UIs such ComfyUI, it can be beacause of multiple reasons as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
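
The bullet removed above is the limitation this commit lifts: Kohya-style SDXL LoRAs that ship text encoder weights can now be loaded end to end. A minimal sketch of that flow (the LoRA filename is hypothetical; `load_lora_weights` routes UNet keys through `load_attn_procs` and text encoder keys through `load_lora_into_text_encoder`):

import torch
from diffusers import StableDiffusionXLPipeline

# Base SDXL pipeline; fp16 keeps memory manageable on a single GPU.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Hypothetical Kohya-trained checkpoint containing UNet *and* text encoder LoRA layers.
pipe.load_lora_weights("path/to/kohya_sdxl_lora.safetensors")

image = pipe("masterpiece, best quality, mountain", num_inference_steps=30).images[0]
image.save("lora_sample.png")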

src/diffusers/loaders.py

Lines changed: 35 additions & 26 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import os
 import re
 import warnings
@@ -258,6 +259,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)
+        is_network_alphas_none = network_alphas is None
 
         if use_safetensors and not is_safetensors_available():
             raise ValueError(
@@ -349,13 +351,20 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
 
                 # Create another `mapped_network_alphas` dictionary so that we can properly map them.
                 if network_alphas is not None:
-                    for k in network_alphas:
+                    network_alphas_ = copy.deepcopy(network_alphas)
+                    for k in network_alphas_:
                         if k.replace(".alpha", "") in key:
-                            mapped_network_alphas.update({attn_processor_key: network_alphas[k]})
+                            mapped_network_alphas.update({attn_processor_key: network_alphas.pop(k)})
+
+            if not is_network_alphas_none:
+                if len(network_alphas) > 0:
+                    raise ValueError(
+                        f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+                    )
 
             if len(state_dict) > 0:
                 raise ValueError(
-                    f"The state_dict has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
+                    f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
                 )
 
             for key, value_dict in lora_grouped_dict.items():
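
The hunk above switches the alpha mapping from a plain read (`network_alphas[k]`) to a consuming `pop`, then asserts that nothing is left once the mapping loop finishes. A standalone sketch of that consume-and-verify pattern, with invented keys purely for illustration:

import copy

def map_alphas(network_alphas, module_keys):
    # Toy version of the mapping loop: move every alpha whose base name matches
    # a module key into a per-module dict, popping entries as they are consumed.
    mapped = {}
    remaining = copy.deepcopy(network_alphas)
    for module_key in module_keys:
        for alpha_key in list(remaining):
            if alpha_key.replace(".alpha", "") in module_key:
                mapped[module_key] = remaining.pop(alpha_key)
    # Leftovers mean an alpha was never matched to a module -- fail loudly,
    # in the spirit of the new check in `load_attn_procs`.
    if remaining:
        raise ValueError(f"Unconsumed network_alphas keys: {', '.join(remaining)}")
    return mapped

# Invented keys, only to exercise the helper.
alphas = {"down_blocks.0.attn1.to_q.alpha": 8.0}
print(map_alphas(alphas, ["down_blocks.0.attn1.to_q.lora.down.weight"]))
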
@@ -434,14 +443,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                         v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"),
                         out_rank=rank_mapping.get("to_out_lora.down.weight"),
                         out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"),
-                        # rank=rank_mapping.get("to_k_lora.down.weight", None),
-                        # hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
-                        # q_rank=rank_mapping.get("to_q_lora.down.weight", None),
-                        # q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight", None),
-                        # v_rank=rank_mapping.get("to_v_lora.down.weight", None),
-                        # v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight", None),
-                        # out_rank=rank_mapping.get("to_out_lora.down.weight", None),
-                        # out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight", None),
                     )
                 else:
                     attn_processors[key] = attn_processor_class(
@@ -496,9 +497,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         # set ff layers
         for target_module, lora_layer in non_attn_lora_layers:
             target_module.set_lora_layer(lora_layer)
-            # It should raise an error if we don't have a set lora here
-            # if hasattr(target_module, "set_lora_layer"):
-            #     target_module.set_lora_layer(lora_layer)
 
     def save_attn_procs(
         self,
@@ -1251,9 +1249,10 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p
         keys = list(state_dict.keys())
         prefix = cls.text_encoder_name if prefix is None else prefix
 
+        # Safe prefix to check with.
         if any(cls.text_encoder_name in key for key in keys):
             # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix)]
+            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
             text_encoder_lora_state_dict = {
                 k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
@@ -1303,6 +1302,14 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p
                 ].shape[1]
                 patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
 
+                if network_alphas is not None:
+                    alpha_keys = [
+                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+                    ]
+                    network_alphas = {
+                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+                    }
+
                 cls._modify_text_encoder(
                     text_encoder,
                     lora_scale,
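
The added filter mirrors the stricter `text_encoder_keys` selection in the earlier hunk: only alphas whose first dot-separated segment equals the prefix (e.g. `text_encoder`, presumably so `text_encoder_2` entries from SDXL's second text encoder are not swept in) survive, and the prefix is stripped before the dict reaches `_modify_text_encoder`. A small sketch of that filtering, with invented keys:

def filter_by_prefix(d, prefix):
    # Keep only keys whose first dot-separated segment is exactly `prefix`,
    # then strip the prefix so downstream code sees module-relative names.
    keys = [k for k in d if k.startswith(prefix) and k.split(".")[0] == prefix]
    return {k.replace(f"{prefix}.", ""): v for k, v in d.items() if k in keys}

# Invented mixed dict: UNet alphas must not leak into the text encoder pass.
network_alphas = {
    "text_encoder.text_model.encoder.layers.0.self_attn.to_q_lora.down.weight.alpha": 8.0,
    "unet.down_blocks.0.attentions.0.alpha": 4.0,
}
print(filter_by_prefix(network_alphas, "text_encoder"))
# -> {'text_model.encoder.layers.0.self_attn.to_q_lora.down.weight.alpha': 8.0}
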
@@ -1364,12 +1371,13 @@ def _modify_text_encoder(
 
         lora_parameters = []
         network_alphas = {} if network_alphas is None else network_alphas
+        is_network_alphas_populated = len(network_alphas) > 0
 
         for name, attn_module in text_encoder_attn_modules(text_encoder):
-            query_alpha = network_alphas.get(name + ".k.proj.alpha")
-            key_alpha = network_alphas.get(name + ".q.proj.alpha")
-            value_alpha = network_alphas.get(name + ".v.proj.alpha")
-            proj_alpha = network_alphas.get(name + ".out.proj.alpha")
+            query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
+            key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
+            value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
+            out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
 
             attn_module.q_proj = PatchedLoraProjection(
                 attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
@@ -1387,14 +1395,14 @@ def _modify_text_encoder(
             lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
 
             attn_module.out_proj = PatchedLoraProjection(
-                attn_module.out_proj, lora_scale, network_alpha=proj_alpha, rank=rank, dtype=dtype
+                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
 
         if patch_mlp:
             for name, mlp_module in text_encoder_mlp_modules(text_encoder):
-                fc1_alpha = network_alphas.get(name + ".fc1.alpha")
-                fc2_alpha = network_alphas.get(name + ".fc2.alpha")
+                fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha")
+                fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha")
 
                 mlp_module.fc1 = PatchedLoraProjection(
                     mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
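
With the renamed lookups, `_modify_text_encoder` now expects alphas keyed by the full diffusers-style LoRA weight name plus `.alpha`, and it pops them so the emptiness check added in the next hunk can fire on anything left behind. Purely as an illustration (the layer index is invented), the per-module keys it looks for have this shape:

# Hypothetical module names in the shape yielded by text_encoder_attn_modules /
# text_encoder_mlp_modules for a CLIP text encoder.
attn_name = "text_model.encoder.layers.0.self_attn"
mlp_name = "text_model.encoder.layers.0.mlp"

expected_alpha_keys = [
    attn_name + ".to_q_lora.down.weight.alpha",
    attn_name + ".to_k_lora.down.weight.alpha",
    attn_name + ".to_v_lora.down.weight.alpha",
    attn_name + ".to_out_lora.down.weight.alpha",
    mlp_name + ".fc1.lora_linear_layer.down.weight.alpha",
    mlp_name + ".fc2.lora_linear_layer.down.weight.alpha",
]
print("\n".join(expected_alpha_keys))
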
@@ -1406,6 +1414,11 @@ def _modify_text_encoder(
                 )
                 lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
 
+        if is_network_alphas_populated and len(network_alphas) > 0:
+            raise ValueError(
+                f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+            )
+
         return lora_parameters
 
     @classmethod
@@ -1519,10 +1532,6 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
             lora_name_up = lora_name + ".lora_up.weight"
             lora_name_alpha = lora_name + ".alpha"
 
-            # if lora_name_alpha in state_dict:
-            #     alpha = state_dict.pop(lora_name_alpha).item()
-            #     network_alphas.update({lora_name_alpha: alpha})
-
             if lora_name.startswith("lora_unet_"):
                 diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
 

tests/models/test_lora_layers.py

Lines changed: 2 additions & 2 deletions
@@ -737,7 +737,7 @@ def test_a1111(self):
         ).images
 
         images = images[0, -3:, -3:, -1].flatten()
-        expected = np.array([0.3725, 0.3767, 0.3761, 0.3796, 0.3827, 0.3763, 0.3831, 0.3809, 0.3392])
+        expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
 
         self.assertTrue(np.allclose(images, expected, atol=1e-4))
 
@@ -760,7 +760,7 @@ def test_vanilla_funetuning(self):
 
         self.assertTrue(np.allclose(images, expected, atol=1e-4))
 
-    def test_unload_lora(self):
+    def test_unload_kohya_lora(self):
         generator = torch.manual_seed(0)
         prompt = "masterpiece, best quality, mountain"
         num_inference_steps = 2
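
Both test tweaks follow the suite's regression pattern: run a short generation, fingerprint the output by slicing a 3x3 corner of the last channel, and compare against hard-coded reference values with a tight tolerance. A generic sketch of that check (the arrays here are placeholders, not the real expectations):

import numpy as np

def check_image_slice(images, expected, atol=1e-4):
    # `images` is a pipeline output of shape (batch, height, width, channels);
    # the bottom-right 3x3 patch of the last channel is a cheap fingerprint.
    actual = images[0, -3:, -3:, -1].flatten()
    assert np.allclose(actual, expected, atol=atol), f"{actual} vs {expected}"

# Placeholder data only -- the real tests compare against the updated values above.
dummy = np.zeros((1, 64, 64, 3), dtype=np.float32)
check_image_slice(dummy, np.zeros(9, dtype=np.float32))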
