 import safetensors

 if is_transformers_available():
-    from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer
+    from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer

 if is_accelerate_available():
     from accelerate import init_empty_weights

@@ -128,6 +128,8 @@ def text_encoder_mlp_modules(text_encoder):
             mlp_mod = layer.mlp
             name = f"text_model.encoder.layers.{i}.mlp"
             mlp_modules.append((name, mlp_mod))
+    elif isinstance(text_encoder, CLIPTextModelWithProjection):
+        pass  # SDXL is not supported yet.
     else:
         raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")

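For context, `text_encoder_mlp_modules` walks the CLIP text encoder and returns its MLP blocks together with their dotted names; the new `elif` simply acknowledges `CLIPTextModelWithProjection` (the SDXL text encoder) without collecting anything yet. A minimal sketch of the helper's behaviour, assuming a standard `transformers` CLIPTextModel (the checkpoint id below is only an example):

    from transformers import CLIPTextModel

    def collect_mlp_modules(text_encoder):
        # Mirrors the helper above: (dotted name, module) pairs for every
        # MLP block of a CLIPTextModel's transformer encoder.
        mlp_modules = []
        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
            name = f"text_model.encoder.layers.{i}.mlp"
            mlp_modules.append((name, layer.mlp))
        return mlp_modules

    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    for name, mlp in collect_mlp_modules(text_encoder):
        print(name, type(mlp).__name__)  # e.g. text_model.encoder.layers.0.mlp CLIPMLP
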
@@ -1128,21 +1130,12 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
                         f"{name}.out_proj.lora_linear_layer.down.weight"
                     ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")

-            for name, _ in text_encoder_mlp_modules(text_encoder):
-                for direction in ["up", "down"]:
-                    for layer in ["fc1", "fc2"]:
-                        original_key = f"{name}.{layer}.lora.{direction}.weight"
-                        replacement_key = f"{name}.{layer}.lora_linear_layer.{direction}.weight"
-                        if original_key in text_encoder_lora_state_dict:
-                            text_encoder_lora_state_dict[replacement_key] = text_encoder_lora_state_dict.pop(
-                                original_key
-                            )
-
             rank = text_encoder_lora_state_dict[
                 "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
             ].shape[1]
+            patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())

-            cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
+            cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)

             # set correct dtype & device
             text_encoder_lora_state_dict = {

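The new `patch_mlp` flag is derived purely from the keys of the already-converted LoRA state dict: if any key touches an `.mlp.` layer, the MLPs get patched as well; otherwise only the attention projections are. A small illustration with placeholder tensors (the keys are hypothetical, the shapes just show where the rank is read from):

    import torch

    text_encoder_lora_state_dict = {
        "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight": torch.zeros(768, 4),
        "text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.down.weight": torch.zeros(4, 768),
    }

    rank = text_encoder_lora_state_dict[
        "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
    ].shape[1]
    patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
    print(rank, patch_mlp)  # 4 True
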
@@ -1187,6 +1180,7 @@ def _modify_text_encoder(
         network_alpha=None,
         rank=4,
         dtype=None,
+        patch_mlp=False,
     ):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.

@@ -1218,12 +1212,17 @@
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

-        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
-            mlp_module.fc1 = PatchedLoraProjection(mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype)
-            lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
+        if patch_mlp:
+            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+                mlp_module.fc1 = PatchedLoraProjection(
+                    mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
+                )
+                lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

-            mlp_module.fc2 = PatchedLoraProjection(mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype)
-            lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
+                mlp_module.fc2 = PatchedLoraProjection(
+                    mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
+                )
+                lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())

         return lora_parameters

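`PatchedLoraProjection` itself is not part of this diff; conceptually it wraps an existing `nn.Linear` and adds a scaled low-rank update on top of the frozen projection, which is why `fc1`/`fc2` can be swapped in place. A rough, simplified sketch of that idea (not the actual diffusers class, and without the network_alpha/dtype handling):

    import torch.nn as nn

    class LoRALinearSketch(nn.Module):
        # Stand-in for the lora_linear_layer referenced above: a rank-r bottleneck
        # (down: in -> rank, up: rank -> out), with up initialised to zero so the
        # patched module initially behaves exactly like the original one.
        def __init__(self, in_features, out_features, rank=4):
            super().__init__()
            self.down = nn.Linear(in_features, rank, bias=False)
            self.up = nn.Linear(rank, out_features, bias=False)
            nn.init.normal_(self.down.weight, std=1.0 / rank)
            nn.init.zeros_(self.up.weight)

        def forward(self, hidden_states):
            return self.up(self.down(hidden_states))

    class PatchedProjectionSketch(nn.Module):
        # Rough analogue of PatchedLoraProjection: keep the original linear and
        # add lora_scale * lora_linear_layer(x) to its output.
        def __init__(self, regular_linear_layer, lora_scale=1.0, rank=4):
            super().__init__()
            self.regular_linear_layer = regular_linear_layer
            self.lora_scale = lora_scale
            self.lora_linear_layer = LoRALinearSketch(
                regular_linear_layer.in_features, regular_linear_layer.out_features, rank=rank
            )

        def forward(self, hidden_states):
            return self.regular_linear_layer(hidden_states) + self.lora_scale * self.lora_linear_layer(
                hidden_states
            )

    # Patching a toy MLP the way the hunk above does it:
    fc1 = PatchedProjectionSketch(nn.Linear(768, 3072), lora_scale=1.0, rank=4)
    fc2 = PatchedProjectionSketch(nn.Linear(3072, 768), lora_scale=1.0, rank=4)
    lora_parameters = list(fc1.lora_linear_layer.parameters()) + list(fc2.lora_linear_layer.parameters())
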
@@ -1363,6 +1362,9 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                     te_state_dict[diffusers_name] = value
                     te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                 elif "mlp" in diffusers_name:
+                    # Be aware that this is the new diffusers convention and the rest of the code might
+                    # not utilize it yet.
+                    diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
                     te_state_dict[diffusers_name] = value
                     te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

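The added rename only touches the MLP entries of the converted text-encoder state dict, moving them from the old `.lora.` naming to the `.lora_linear_layer.` convention expected by the patched modules. A quick before/after on an illustrative key:

    # Illustrative key as it looks after the kohya -> diffusers renaming earlier in the function.
    diffusers_name = "text_model.encoder.layers.0.mlp.fc1.lora.down.weight"

    if "mlp" in diffusers_name:
        # New diffusers convention for MLP LoRA layers.
        diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")

    print(diffusers_name)
    # text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.down.weight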