
Commit bcd9aa9

Stop using auxilary states
1 parent debf514 commit bcd9aa9

4 files changed (+69 -101 lines)


examples/dreambooth/train_dreambooth_lora.py

+1-1
```diff
@@ -924,7 +924,7 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict, network_alpha, _ = LoraLoaderMixin.lora_state_dict(input_dir)
+        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
         LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
         LoraLoaderMixin.load_lora_into_text_encoder(
             lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
```

examples/dreambooth/train_dreambooth_lora_sdxl.py

+1-1
```diff
@@ -836,7 +836,7 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict, network_alpha, _ = LoraLoaderMixin.lora_state_dict(input_dir)
+        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
         LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
         LoraLoaderMixin.load_lora_into_text_encoder(
             lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
```
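For downstream code that calls `lora_state_dict` directly, the migration mirrors these two scripts: the auxiliary third return value is gone, so the trailing `_` placeholder is simply dropped. A minimal sketch, where the checkpoint directory is a hypothetical placeholder:

```python
from diffusers.loaders import LoraLoaderMixin

input_dir = "path/to/lora-checkpoint"  # hypothetical directory written by the training script

# Before this commit: lora_state_dict, network_alpha, _ = LoraLoaderMixin.lora_state_dict(input_dir)
lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
```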

src/diffusers/loaders.py

+61-99
```diff
@@ -120,18 +120,18 @@ def text_encoder_attn_modules(text_encoder):
     return attn_modules
 
 
-def text_encoder_aux_modules(text_encoder):
-    aux_modules = []
+def text_encoder_mlp_modules(text_encoder):
+    mlp_modules = []
 
     if isinstance(text_encoder, CLIPTextModel):
         for i, layer in enumerate(text_encoder.text_model.encoder.layers):
             mlp_mod = layer.mlp
             name = f"text_model.encoder.layers.{i}.mlp"
-            aux_modules.append((name, mlp_mod))
+            mlp_modules.append((name, mlp_mod))
     else:
-        raise ValueError(f"do not know how to get aux modules for: {text_encoder.__class__.__name__}")
+        raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
 
-    return aux_modules
+    return mlp_modules
 
 
 def text_encoder_lora_state_dict(text_encoder):
@@ -322,6 +322,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
 
         # fill attn processors
         attn_processors = {}
+        ff_layers = []
 
         is_lora = all("lora" in k for k in state_dict.keys())
         is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
@@ -345,13 +346,32 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                 lora_grouped_dict[attn_processor_key][sub_key] = value
 
             for key, value_dict in lora_grouped_dict.items():
-                rank = value_dict["to_k_lora.down.weight"].shape[0]
-                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
-
                 attn_processor = self
                 for sub_key in key.split("."):
                     attn_processor = getattr(attn_processor, sub_key)
 
+                # Process FF layers
+                if "lora.down.weight" in value_dict:
+                    rank = value_dict["lora.down.weight"].shape[0]
+                    hidden_size = value_dict["lora.up.weight"].shape[0]
+
+                    if isinstance(attn_processor, LoRACompatibleConv):
+                        lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
+                    elif isinstance(attn_processor, LoRACompatibleLinear):
+                        lora = LoRALinearLayer(
+                            attn_processor.in_features, attn_processor.out_features, rank, network_alpha
+                        )
+                    else:
+                        raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
+
+                    value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+                    lora.load_state_dict(value_dict)
+                    ff_layers.append((attn_processor, lora))
+                    continue
+
+                rank = value_dict["to_k_lora.down.weight"].shape[0]
+                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
+
                 if isinstance(
                     attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
                 ):
@@ -408,10 +428,16 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
 
         # set correct dtype & device
         attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
+        ff_layers = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in ff_layers]
 
         # set layers
         self.set_attn_processor(attn_processors)
 
+        # set ff layers
+        for target_module, lora_layer in ff_layers:
+            if hasattr(target_module, "set_lora_layer"):
+                target_module.set_lora_layer(lora_layer)
+
     def save_attn_procs(
         self,
         save_directory: Union[str, os.PathLike],
@@ -489,36 +515,6 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
-    def _load_lora_aux(self, state_dict, network_alpha=None):
-        lora_grouped_dict = defaultdict(dict)
-        for key, value in state_dict.items():
-            attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
-            lora_grouped_dict[attn_processor_key][sub_key] = value
-
-        for key, value_dict in lora_grouped_dict.items():
-            rank = value_dict["lora.down.weight"].shape[0]
-            hidden_size = value_dict["lora.up.weight"].shape[0]
-            target_modules = [module for name, module in self.named_modules() if name == key]
-            if len(target_modules) == 0:
-                logger.warning(f"Could not find module {key} in the model. Skipping.")
-                continue
-
-            target_module = target_modules[0]
-            value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
-
-            lora = None
-            if isinstance(target_module, LoRACompatibleConv):
-                lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
-            elif isinstance(target_module, LoRACompatibleLinear):
-                lora = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha)
-            else:
-                raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
-            lora.load_state_dict(value_dict)
-            lora.to(device=self.device, dtype=self.dtype)
-
-            # install lora
-            target_module.lora_layer = lora
-
 
 class TextualInversionLoaderMixin:
     r"""
```
```diff
@@ -880,18 +876,13 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             kwargs:
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
-        state_dict, network_alpha, (unet_state_dict_aux, te_state_dict_aux) = self.lora_state_dict(
-            pretrained_model_name_or_path_or_dict, **kwargs
-        )
-        self.load_lora_into_unet(
-            state_dict, network_alpha=network_alpha, unet=self.unet, state_dict_aux=unet_state_dict_aux
-        )
+        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
         self.load_lora_into_text_encoder(
             state_dict,
             network_alpha=network_alpha,
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
-            state_dict_aux=te_state_dict_aux,
         )
 
     @classmethod
@@ -1025,14 +1016,13 @@ def lora_state_dict(
 
         # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
         network_alpha = None
-        auxilary_states = ({}, {})
         if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
-            state_dict, network_alpha, auxilary_states = cls._convert_kohya_lora_to_diffusers(state_dict)
+            state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict)
 
-        return state_dict, network_alpha, auxilary_states
+        return state_dict, network_alpha
 
     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alpha, unet, state_dict_aux=None):
+    def load_lora_into_unet(cls, state_dict, network_alpha, unet):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`
 
@@ -1045,8 +1035,6 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet, state_dict_aux=None):
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
-            state_dict_aux (`dict`, *optional*):
-                A dictionary containing the auxilary state (additional lora state) dict for the unet.
         """
 
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1071,12 +1059,8 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet, state_dict_aux=None):
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
             warnings.warn(warn_message)
 
-        if state_dict_aux:
-            unet._load_lora_aux(state_dict_aux, network_alpha=network_alpha)
-            unet.aux_state_dict_populated = True
-
     @classmethod
-    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0, state_dict_aux=None):
+    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
 
@@ -1091,8 +1075,6 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0, state_dict_aux=None):
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the regular
                 lora layer.
-            state_dict_aux (`dict`, *optional*):
-                A dictionary containing the auxilary state dict (additional lora state) for the text encoder.
         """
 
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1105,8 +1087,6 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0, state_dict_aux=None):
         text_encoder_lora_state_dict = {
             k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
         }
-        if state_dict_aux:
-            text_encoder_lora_state_dict = {**text_encoder_lora_state_dict, **state_dict_aux}
 
         if len(text_encoder_lora_state_dict) > 0:
             logger.info(f"Loading {cls.text_encoder_name}.")
@@ -1148,8 +1128,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0, state_dict_aux=None):
                         f"{name}.out_proj.lora_linear_layer.down.weight"
                     ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
 
-                if state_dict_aux:
-                    for name, _ in text_encoder_aux_modules(text_encoder):
+                for name, _ in text_encoder_mlp_modules(text_encoder):
                     for direction in ["up", "down"]:
                         for layer in ["fc1", "fc2"]:
                             original_key = f"{name}.{layer}.lora.{direction}.weight"
@@ -1163,9 +1142,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0, state_dict_aux=None):
                 "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
             ].shape[1]
 
-            cls._modify_text_encoder(
-                text_encoder, lora_scale, network_alpha, rank=rank, patch_aux=bool(state_dict_aux)
-            )
+            cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
 
             # set correct dtype & device
             text_encoder_lora_state_dict = {
@@ -1197,13 +1174,10 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
                 attn_module.v_proj = attn_module.v_proj.regular_linear_layer
                 attn_module.out_proj = attn_module.out_proj.regular_linear_layer
 
-        if getattr(text_encoder, "aux_state_dict_populated", False):
-            for _, aux_module in text_encoder_aux_modules(text_encoder):
-                if isinstance(aux_module.fc1, PatchedLoraProjection):
-                    aux_module.fc1 = aux_module.fc1.regular_linear_layer
-                    aux_module.fc2 = aux_module.fc2.regular_linear_layer
-
-            text_encoder.aux_state_dict_populated = False
+        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+            if isinstance(mlp_module.fc1, PatchedLoraProjection):
+                mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
+                mlp_module.fc2 = mlp_module.fc2.regular_linear_layer
 
     @classmethod
     def _modify_text_encoder(
@@ -1213,7 +1187,6 @@ def _modify_text_encoder(
         network_alpha=None,
         rank=4,
         dtype=None,
-        patch_aux=False,
     ):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
@@ -1245,19 +1218,12 @@ def _modify_text_encoder(
             )
            lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
 
-        if patch_aux:
-            for _, aux_module in text_encoder_aux_modules(text_encoder):
-                aux_module.fc1 = PatchedLoraProjection(
-                    aux_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
-                )
-                lora_parameters.extend(aux_module.fc1.lora_linear_layer.parameters())
-
-                aux_module.fc2 = PatchedLoraProjection(
-                    aux_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
-                )
-                lora_parameters.extend(aux_module.fc2.lora_linear_layer.parameters())
+        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+            mlp_module.fc1 = PatchedLoraProjection(mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype)
+            lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
 
-            text_encoder.aux_state_dict_populated = True
+            mlp_module.fc2 = PatchedLoraProjection(mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype)
+            lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
 
         return lora_parameters
 
@@ -1343,8 +1309,6 @@ def save_function(weights, filename):
     def _convert_kohya_lora_to_diffusers(cls, state_dict):
         unet_state_dict = {}
         te_state_dict = {}
-        unet_state_dict_aux = {}
-        te_state_dict_aux = {}
         network_alpha = None
         unloaded_keys = []
 
@@ -1381,11 +1345,11 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                         unet_state_dict[diffusers_name] = value
                         unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                     elif "ff" in diffusers_name:
-                        unet_state_dict_aux[diffusers_name] = value
-                        unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                        unet_state_dict[diffusers_name] = value
+                        unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                     elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
-                        unet_state_dict_aux[diffusers_name] = value
-                        unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                        unet_state_dict[diffusers_name] = value
+                        unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
 
                 elif lora_name.startswith("lora_te_"):
                     diffusers_name = key.replace("lora_te_", "").replace("_", ".")
@@ -1399,8 +1363,8 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                         te_state_dict[diffusers_name] = value
                         te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                     elif "mlp" in diffusers_name:
-                        te_state_dict_aux[diffusers_name] = value
-                        te_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                        te_state_dict[diffusers_name] = value
+                        te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
 
         logger.info("Kohya-style checkpoint detected.")
         if len(unloaded_keys) > 0:
@@ -1412,7 +1376,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
         unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
         te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
         new_state_dict = {**unet_state_dict, **te_state_dict}
-        return new_state_dict, network_alpha, (unet_state_dict_aux, te_state_dict_aux)
+        return new_state_dict, network_alpha
 
     def unload_lora_weights(self):
         """
@@ -1442,11 +1406,9 @@ def unload_lora_weights(self):
         [attention_proc_class] = unet_attention_classes
         self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())
 
-        if getattr(self.unet, "aux_state_dict_populated", None):
-            for _, module in self.unet.named_modules():
-                if hasattr(module, "lora_layer") and module.lora_layer is not None:
-                    module.lora_layer = None
-            self.unet.aux_state_dict_populated = False
+        for _, module in self.unet.named_modules():
+            if hasattr(module, "set_lora_layer"):
+                module.set_lora_layer(None)
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
```
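With feed-forward and projection LoRA layers installed through `set_lora_layer`, unloading no longer depends on the `aux_state_dict_populated` flag: any module that exposes the setter can be cleared unconditionally. A minimal sketch of that walk, using a toy module tree as a stand-in for the UNet:

```python
import torch.nn as nn

from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

# Toy stand-in for a UNet subtree: one LoRA-compatible layer among ordinary modules.
net = nn.Sequential(nn.Linear(8, 8), LoRACompatibleLinear(8, 8))
net[1].set_lora_layer(LoRALinearLayer(8, 8, 4, None))

# The new unload_lora_weights logic: a duck-typed walk instead of a populated/unpopulated flag.
for _, module in net.named_modules():
    if hasattr(module, "set_lora_layer"):
        module.set_lora_layer(None)

assert net[1].lora_layer is None
```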

src/diffusers/models/lora.py

+6
```diff
@@ -87,6 +87,9 @@ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs
         super().__init__(*args, **kwargs)
         self.lora_layer = lora_layer
 
+    def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
+        self.lora_layer = lora_layer
+
     def forward(self, x):
         if self.lora_layer is None:
             return super().forward(x)
@@ -103,6 +106,9 @@ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs
         super().__init__(*args, **kwargs)
         self.lora_layer = lora_layer
 
+    def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
+        self.lora_layer = lora_layer
+
     def forward(self, x):
         if self.lora_layer is None:
             return super().forward(x)
```
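One observable consequence of the `forward` guard shown above: with no LoRA layer set, or after it is cleared through the new setter, a `LoRACompatibleLinear` is numerically identical to the `nn.Linear` it subclasses. A small sanity-check sketch with arbitrary layer sizes:

```python
import torch
import torch.nn.functional as F

from diffusers.models.lora import LoRACompatibleLinear

layer = LoRACompatibleLinear(16, 32)
x = torch.randn(4, 16)

# No LoRA layer set: forward() falls through to nn.Linear.forward.
assert torch.allclose(layer(x), F.linear(x, layer.weight, layer.bias))

# Clearing through the new setter keeps that guarantee.
layer.set_lora_layer(None)
assert torch.allclose(layer(x), F.linear(x, layer.weight, layer.bias))
```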
