
Commit 67a2dec

isidentical, takuma104, and sayakpaul authored and committed

Load Kohya-ss style LoRAs with auxiliary states (huggingface#4147)

* Support to load Kohya-ss style LoRA file format (without restrictions)
* tmp: add sdxl to mlp_modules

Co-authored-by: Takuma Mori <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 6c8c4a2 commit 67a2dec

File tree

8 files changed: +286 -46 lines changed

docs/source/en/training/lora.mdx  (+3 -2)

@@ -14,8 +14,7 @@ specific language governing permissions and limitations under the License.
 
 <Tip warning={true}>
 
-Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`]. We also
-support fine-tuning the text encoder for DreamBooth with LoRA in a limited capacity. Fine-tuning the text encoder for DreamBooth generally yields better results, but it can increase compute usage.
+This is an experimental feature. Its APIs can change in future.
 
 </Tip>
 
@@ -286,6 +285,8 @@ You can call [`~diffusers.loaders.LoraLoaderMixin.unload_lora_weights`] on a pip
 
 ## Supporting A1111 themed LoRA checkpoints from Diffusers
 
+This support was made possible because of our amazing contributors: [@takuma104](https://github.com/takuma104) and [@isidentical](https://github.com/isidentical).
+
 To provide seamless interoperability with A1111 to our users, we support loading A1111 formatted
 LoRA checkpoints using [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] in a limited capacity.
 In this section, we explain how to load an A1111 formatted LoRA checkpoint from [CivitAI](https://civitai.com/)
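For context, Kohya-ss/A1111 style checkpoints go through the same `load_lora_weights` entry point documented above. A minimal usage sketch of that path (the model ID, file name, and prompt below are placeholders, not part of this commit):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Load a Kohya-ss / A1111 formatted LoRA file (e.g. a .safetensors download from CivitAI).
pipe.load_lora_weights(".", weight_name="example_lora.safetensors")

image = pipe("masterpiece, best quality, mountain landscape", num_inference_steps=25).images[0]

# Detach the LoRA layers again (attention processors, feed-forward layers, text encoder patches).
pipe.unload_lora_weights()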

src/diffusers/loaders.py  (+109 -7)

@@ -25,6 +25,7 @@
 from huggingface_hub import hf_hub_download
 from torch import nn
 
+from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -56,6 +57,7 @@
 
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+TOTAL_EXAMPLE_KEYS = 5
 
 TEXT_INVERSION_NAME = "learned_embeds.bin"
 TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -105,6 +107,20 @@ def text_encoder_attn_modules(text_encoder):
     return attn_modules
 
 
+def text_encoder_mlp_modules(text_encoder):
+    mlp_modules = []
+
+    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+            mlp_mod = layer.mlp
+            name = f"text_model.encoder.layers.{i}.mlp"
+            mlp_modules.append((name, mlp_mod))
+    else:
+        raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
+
+    return mlp_modules
+
+
 def text_encoder_lora_state_dict(text_encoder):
     state_dict = {}
 
@@ -304,6 +320,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
 
        # fill attn processors
        attn_processors = {}
+       non_attn_lora_layers = []
 
        is_lora = all("lora" in k for k in state_dict.keys())
        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
@@ -327,13 +344,33 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                lora_grouped_dict[attn_processor_key][sub_key] = value
 
            for key, value_dict in lora_grouped_dict.items():
-               rank = value_dict["to_k_lora.down.weight"].shape[0]
-               hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
-
                attn_processor = self
                for sub_key in key.split("."):
                    attn_processor = getattr(attn_processor, sub_key)
 
+               # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
+               # or add_{k,v,q,out_proj}_proj_lora layers.
+               if "lora.down.weight" in value_dict:
+                   rank = value_dict["lora.down.weight"].shape[0]
+                   hidden_size = value_dict["lora.up.weight"].shape[0]
+
+                   if isinstance(attn_processor, LoRACompatibleConv):
+                       lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
+                   elif isinstance(attn_processor, LoRACompatibleLinear):
+                       lora = LoRALinearLayer(
+                           attn_processor.in_features, attn_processor.out_features, rank, network_alpha
+                       )
+                   else:
+                       raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
+
+                   value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+                   lora.load_state_dict(value_dict)
+                   non_attn_lora_layers.append((attn_processor, lora))
+                   continue
+
+               rank = value_dict["to_k_lora.down.weight"].shape[0]
+               hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
+
                if isinstance(
                    attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
                ):
@@ -390,10 +427,16 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
 
        # set correct dtype & device
        attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
+       non_attn_lora_layers = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in non_attn_lora_layers]
 
        # set layers
        self.set_attn_processor(attn_processors)
 
+       # set ff layers
+       for target_module, lora_layer in non_attn_lora_layers:
+           if hasattr(target_module, "set_lora_layer"):
+               target_module.set_lora_layer(lora_layer)
+
    def save_attn_procs(
        self,
        save_directory: Union[str, os.PathLike],
@@ -840,7 +883,10 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
        self.load_lora_into_text_encoder(
-           state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale
+           state_dict,
+           network_alpha=network_alpha,
+           text_encoder=self.text_encoder,
+           lora_scale=self.lora_scale,
        )
 
    @classmethod
@@ -1049,6 +1095,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr
        text_encoder_lora_state_dict = {
            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
        }
+
        if len(text_encoder_lora_state_dict) > 0:
            logger.info(f"Loading {prefix}.")
 
@@ -1092,8 +1139,9 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr
            rank = text_encoder_lora_state_dict[
                "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
            ].shape[1]
+           patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
 
-           cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
+           cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)
 
            # set correct dtype & device
            text_encoder_lora_state_dict = {
@@ -1125,8 +1173,21 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
                attn_module.v_proj = attn_module.v_proj.regular_linear_layer
                attn_module.out_proj = attn_module.out_proj.regular_linear_layer
 
+       for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+           if isinstance(mlp_module.fc1, PatchedLoraProjection):
+               mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
+               mlp_module.fc2 = mlp_module.fc2.regular_linear_layer
+
    @classmethod
-   def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+   def _modify_text_encoder(
+       cls,
+       text_encoder,
+       lora_scale=1,
+       network_alpha=None,
+       rank=4,
+       dtype=None,
+       patch_mlp=False,
+   ):
        r"""
        Monkey-patches the forward passes of attention modules of the text encoder.
        """
@@ -1157,6 +1218,18 @@ def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, ra
            )
            lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
 
+       if patch_mlp:
+           for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+               mlp_module.fc1 = PatchedLoraProjection(
+                   mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
+               )
+               lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
+
+               mlp_module.fc2 = PatchedLoraProjection(
+                   mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
+               )
+               lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
+
        return lora_parameters
 
    @classmethod
@@ -1261,9 +1334,12 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
        unet_state_dict = {}
        te_state_dict = {}
        network_alpha = None
+       unloaded_keys = []
 
        for key, value in state_dict.items():
-           if "lora_down" in key:
+           if "hada" in key or "skip" in key:
+               unloaded_keys.append(key)
+           elif "lora_down" in key:
                lora_name = key.split(".")[0]
                lora_name_up = lora_name + ".lora_up.weight"
                lora_name_alpha = lora_name + ".alpha"
@@ -1284,12 +1360,21 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
                    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
                    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+                   diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+                   diffusers_name = diffusers_name.replace("proj.out", "proj_out")
                    if "transformer_blocks" in diffusers_name:
                        if "attn1" in diffusers_name or "attn2" in diffusers_name:
                            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
                            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
                            unet_state_dict[diffusers_name] = value
                            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                       elif "ff" in diffusers_name:
+                           unet_state_dict[diffusers_name] = value
+                           unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                   elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
+                       unet_state_dict[diffusers_name] = value
+                       unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
                elif lora_name.startswith("lora_te_"):
                    diffusers_name = key.replace("lora_te_", "").replace("_", ".")
                    diffusers_name = diffusers_name.replace("text.model", "text_model")
@@ -1301,6 +1386,19 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                    if "self_attn" in diffusers_name:
                        te_state_dict[diffusers_name] = value
                        te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                   elif "mlp" in diffusers_name:
+                       # Be aware that this is the new diffusers convention and the rest of the code might
+                       # not utilize it yet.
+                       diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+                       te_state_dict[diffusers_name] = value
+                       te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
+       logger.info("Kohya-style checkpoint detected.")
+       if len(unloaded_keys) > 0:
+           example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
+           logger.warning(
+               f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
+           )
 
        unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
        te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
@@ -1346,6 +1444,10 @@ def unload_lora_weights(self):
        [attention_proc_class] = unet_attention_classes
        self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())
 
+       for _, module in self.unet.named_modules():
+           if hasattr(module, "set_lora_layer"):
+               module.set_lora_layer(None)
+
        # Safe to call the following regardless of LoRA.
        self._remove_text_encoder_monkey_patch()
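To see which parts of a Kohya-ss checkpoint the converter above will pick up or skip, its raw keys can be inspected directly. A hedged sketch (the file name is a placeholder; the key patterns follow the Kohya-ss naming handled in `_convert_kohya_lora_to_diffusers`):

import safetensors.torch

state_dict = safetensors.torch.load_file("example_kohya_lora.safetensors")

# Keys the converter skips with a warning (LoHa / skip-connection weights).
skipped = [k for k in state_dict if "hada" in k or "skip" in k]
# Key families newly handled by this commit.
te_mlp_keys = [k for k in state_dict if k.startswith("lora_te_") and "_mlp_" in k]
unet_ff_keys = [k for k in state_dict if k.startswith("lora_unet_") and "_ff_" in k]
unet_proj_keys = [k for k in state_dict if k.startswith("lora_unet_") and ("_proj_in" in k or "_proj_out" in k)]

print(f"skipped: {len(skipped)}, text-encoder MLP: {len(te_mlp_keys)}, "
      f"UNet feed-forward: {len(unet_ff_keys)}, UNet proj_in/out: {len(unet_proj_keys)}")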

src/diffusers/models/attention.py  (+3 -2)

@@ -21,6 +21,7 @@
 from .activations import get_activation
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
+from .lora import LoRACompatibleLinear
 
 
 @maybe_allow_in_graph
@@ -245,7 +246,7 @@ def __init__(
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
-       self.net.append(nn.Linear(inner_dim, dim_out))
+       self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))
@@ -289,7 +290,7 @@ class GEGLU(nn.Module):
 
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
-       self.proj = nn.Linear(dim_in, dim_out * 2)
+       self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
 
    def gelu(self, gate):
        if gate.device.type != "mps":
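The `set_lora_layer` hook used in `loaders.py` relies on these `LoRACompatible*` wrappers, which live in `src/diffusers/models/lora.py` and are not shown in this diff. A rough sketch of the idea only, under the assumption that the wrapper adds the LoRA output on top of the base linear output (class name is made up for illustration):

import torch.nn as nn

class LoRACompatibleLinearSketch(nn.Linear):
    """nn.Linear that can optionally carry a LoRA layer whose output is added to its own."""

    def __init__(self, *args, lora_layer=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer):
        # `load_attn_procs` attaches a loaded LoRA layer here;
        # `unload_lora_weights` calls this with None to detach it again.
        self.lora_layer = lora_layer

    def forward(self, hidden_states):
        out = super().forward(hidden_states)
        if self.lora_layer is not None:
            out = out + self.lora_layer(hidden_states)
        return out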

src/diffusers/models/attention_processor.py  (+1 -30)

@@ -19,6 +19,7 @@
 
 from ..utils import deprecate, logging, maybe_allow_in_graph
 from ..utils.import_utils import is_xformers_available
+from .lora import LoRALinearLayer
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -505,36 +506,6 @@ def __call__(
        return hidden_states
 
 
-class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
-        super().__init__()
-
-        if rank > min(in_features, out_features):
-            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
-
-        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
-        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
-        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
-        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
-        self.network_alpha = network_alpha
-        self.rank = rank
-
-        nn.init.normal_(self.down.weight, std=1 / rank)
-        nn.init.zeros_(self.up.weight)
-
-    def forward(self, hidden_states):
-        orig_dtype = hidden_states.dtype
-        dtype = self.down.weight.dtype
-
-        down_hidden_states = self.down(hidden_states.to(dtype))
-        up_hidden_states = self.up(down_hidden_states)
-
-        if self.network_alpha is not None:
-            up_hidden_states *= self.network_alpha / self.rank
-
-        return up_hidden_states.to(orig_dtype)
-
-
 class LoRAAttnProcessor(nn.Module):
    r"""
    Processor for implementing the LoRA attention mechanism.