
Commit 424a344

Patch MLPs optionally / use the new convention
1 parent 438845b commit 424a344

File tree

1 file changed (+19 −17)


src/diffusers/loaders.py

+19 −17
@@ -59,7 +59,7 @@
     import safetensors

 if is_transformers_available():
-    from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer
+    from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer

 if is_accelerate_available():
     from accelerate import init_empty_weights
@@ -128,6 +128,8 @@ def text_encoder_mlp_modules(text_encoder):
             mlp_mod = layer.mlp
             name = f"text_model.encoder.layers.{i}.mlp"
             mlp_modules.append((name, mlp_mod))
+    elif isinstance(text_encoder, CLIPTextModelWithProjection):
+        pass  # SDXL is not supported yet.
     else:
         raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")

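Note: the new branch recognizes CLIPTextModelWithProjection (SDXL's second text encoder) but intentionally yields no MLP modules for it yet. Below is a hypothetical usage sketch of the helper on a plain CLIPTextModel; the checkpoint name is only an example, not something this commit references:

```python
# Hypothetical usage sketch, not part of this commit: for a plain CLIPTextModel,
# text_encoder_mlp_modules (patched above) yields (name, module) pairs in the
# naming scheme shown in the hunk.
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
for name, mlp in text_encoder_mlp_modules(text_encoder):
    print(name, type(mlp).__name__)  # e.g. "text_model.encoder.layers.0.mlp" CLIPMLP
```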
@@ -1128,21 +1130,12 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
                        f"{name}.out_proj.lora_linear_layer.down.weight"
                    ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")

-            for name, _ in text_encoder_mlp_modules(text_encoder):
-                for direction in ["up", "down"]:
-                    for layer in ["fc1", "fc2"]:
-                        original_key = f"{name}.{layer}.lora.{direction}.weight"
-                        replacement_key = f"{name}.{layer}.lora_linear_layer.{direction}.weight"
-                        if original_key in text_encoder_lora_state_dict:
-                            text_encoder_lora_state_dict[replacement_key] = text_encoder_lora_state_dict.pop(
-                                original_key
-                            )
-
             rank = text_encoder_lora_state_dict[
                 "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
             ].shape[1]
+            patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())

-            cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
+            cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)

             # set correct dtype & device
             text_encoder_lora_state_dict = {
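The eager key-renaming loop is removed here; that conversion now happens once in `_convert_kohya_lora_to_diffusers` (last hunk below), and MLP patching is only requested when the state dict actually carries MLP weights. A minimal sketch of the detection step, with made-up keys in the naming scheme used throughout this file:

```python
# Made-up example keys, following the naming scheme used in this file.
attn_key = "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
mlp_key = "text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.up.weight"

# Attention-only LoRA: the text encoder MLPs stay untouched.
assert not any(".mlp." in key for key in [attn_key])
# A LoRA that also trained the MLPs flips patch_mlp to True.
assert any(".mlp." in key for key in [attn_key, mlp_key])
```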
@@ -1187,6 +1180,7 @@ def _modify_text_encoder(
         network_alpha=None,
         rank=4,
         dtype=None,
+        patch_mlp=False,
     ):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
@@ -1218,12 +1212,17 @@
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

-        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
-            mlp_module.fc1 = PatchedLoraProjection(mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype)
-            lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
+        if patch_mlp:
+            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+                mlp_module.fc1 = PatchedLoraProjection(
+                    mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
+                )
+                lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

-            mlp_module.fc2 = PatchedLoraProjection(mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype)
-            lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
+                mlp_module.fc2 = PatchedLoraProjection(
+                    mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
+                )
+                lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())

         return lora_parameters

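`PatchedLoraProjection`'s internals are outside this diff; the simplified stand-in below only illustrates the monkey-patching pattern the loop applies to `fc1` and `fc2`: wrap an `nn.Linear` so a scaled low-rank correction is added to its output. The class name and sizes are illustrative, not the library's API:

```python
import torch
import torch.nn as nn

class TinyLoraProjection(nn.Module):
    """Illustrative stand-in for PatchedLoraProjection, not the real implementation."""

    def __init__(self, regular_linear, lora_scale=1.0, rank=4):
        super().__init__()
        self.regular_linear = regular_linear
        self.lora_scale = lora_scale
        # Low-rank pair; up is zero-initialized so the patch starts as a no-op.
        self.down = nn.Linear(regular_linear.in_features, rank, bias=False)
        self.up = nn.Linear(rank, regular_linear.out_features, bias=False)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        base = self.regular_linear(hidden_states)
        return base + self.lora_scale * self.up(self.down(hidden_states))

fc1 = nn.Linear(512, 2048)
fc1 = TinyLoraProjection(fc1, lora_scale=0.5, rank=4)  # same rebinding move as the loop above
out = fc1(torch.randn(1, 77, 512))                     # shape (1, 77, 2048), unchanged
```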
@@ -1363,6 +1362,9 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                 te_state_dict[diffusers_name] = value
                 te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
             elif "mlp" in diffusers_name:
+                # Be aware that this is the new diffusers convention and the rest of the code might
+                # not utilize it yet.
+                diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
                 te_state_dict[diffusers_name] = value
                 te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

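For reference, the renaming rule this hunk adds for MLP keys, applied to a made-up Kohya-derived key (only the `.replace` call comes from the diff):

```python
old_key = "text_model.encoder.layers.3.mlp.fc1.lora.down.weight"
new_key = old_key.replace(".lora.", ".lora_linear_layer.")
assert new_key == "text_model.encoder.layers.3.mlp.fc1.lora_linear_layer.down.weight"
```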