Commit 873d04a

Make the auxiliary text encoder patching more reliable with custom projector
1 parent 9a7a6e8 commit 873d04a

File tree: 2 files changed, +44 / -34 lines

src/diffusers/loaders.py (+41, -32)
@@ -109,6 +109,20 @@ def text_encoder_attn_modules(text_encoder):
     return attn_modules


+def text_encoder_aux_modules(text_encoder):
+    aux_modules = []
+
+    if isinstance(text_encoder, CLIPTextModel):
+        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+            mlp_mod = layer.mlp
+            name = f"text_model.encoder.layers.{i}.mlp"
+            aux_modules.append((name, mlp_mod))
+    else:
+        raise ValueError(f"do not know how to get aux modules for: {text_encoder.__class__.__name__}")
+
+    return aux_modules
+
+
 def text_encoder_lora_state_dict(text_encoder):
     state_dict = {}

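The new helper mirrors text_encoder_attn_modules but collects the per-layer MLP blocks instead of the attention projections. A minimal sketch of how it could be exercised on this branch; the tiny CLIPTextConfig values are illustrative only, and the import path assumes the helper lives in diffusers.loaders as the diff defines it:

from transformers import CLIPTextConfig, CLIPTextModel

from diffusers.loaders import text_encoder_aux_modules  # added by this commit

# Small illustrative config; any CLIPTextModel works the same way.
config = CLIPTextConfig(hidden_size=32, intermediate_size=64, num_hidden_layers=2, num_attention_heads=4)
text_encoder = CLIPTextModel(config)

for name, mlp in text_encoder_aux_modules(text_encoder):
    # Each entry pairs "text_model.encoder.layers.<i>.mlp" with that layer's CLIPMLP module.
    print(name, type(mlp).__name__)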

@@ -1079,6 +1093,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
             text_encoder_lora_state_dict = {
                 k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
+            text_encoder_lora_state_dict = {**text_encoder_lora_state_dict, **state_dict_aux}
             if len(text_encoder_lora_state_dict) > 0:
                 logger.info(f"Loading {cls.text_encoder_name}.")


@@ -1119,13 +1134,26 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
                             f"{name}.out_proj.lora_linear_layer.down.weight"
                         ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")

+                    for name, _ in text_encoder_aux_modules(text_encoder):
+                        text_encoder_lora_state_dict[
+                            f"{name}.fc1.lora_linear_layer.up.weight"
+                        ] = text_encoder_lora_state_dict.pop(f"{name}.fc1.lora.up.weight")
+                        text_encoder_lora_state_dict[
+                            f"{name}.fc2.lora_linear_layer.up.weight"
+                        ] = text_encoder_lora_state_dict.pop(f"{name}.fc2.lora.up.weight")
+
+                        text_encoder_lora_state_dict[
+                            f"{name}.fc1.lora_linear_layer.down.weight"
+                        ] = text_encoder_lora_state_dict.pop(f"{name}.fc1.lora.down.weight")
+                        text_encoder_lora_state_dict[
+                            f"{name}.fc2.lora_linear_layer.down.weight"
+                        ] = text_encoder_lora_state_dict.pop(f"{name}.fc2.lora.down.weight")
+
                 rank = text_encoder_lora_state_dict[
                     "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
                 ].shape[1]

                 cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
-                if state_dict_aux:
-                    cls._load_lora_aux_for_text_encoder(text_encoder, state_dict_aux, network_alpha=network_alpha)

                 # set correct dtype & device
                 text_encoder_lora_state_dict = {
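The added loop brings the auxiliary keys in line with the attention keys: the serialized `<module>.fc{1,2}.lora.{up,down}.weight` entries become `<module>.fc{1,2}.lora_linear_layer.{up,down}.weight`, which is the naming that the PatchedLoraProjection modules installed by _modify_text_encoder expect. A toy illustration of the rename for a single key; the tensor and its shape are placeholders, real values come from the loaded LoRA checkpoint:

import torch

aux_key = "text_model.encoder.layers.0.mlp.fc1.lora.up.weight"
state = {aux_key: torch.zeros(64, 4)}  # placeholder tensor

# Same transformation the loop applies: "lora." -> "lora_linear_layer."
state[aux_key.replace(".lora.", ".lora_linear_layer.")] = state.pop(aux_key)
print(list(state))  # ['text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.up.weight']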
@@ -1157,36 +1185,10 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
                 attn_module.v_proj = attn_module.v_proj.regular_linear_layer
                 attn_module.out_proj = attn_module.out_proj.regular_linear_layer

-    @classmethod
-    def _load_lora_aux_for_text_encoder(cls, text_encoder, state_dict, network_alpha=None):
-        lora_grouped_dict = defaultdict(dict)
-        for key, value in state_dict.items():
-            attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
-            lora_grouped_dict[attn_processor_key][sub_key] = value
-
-        for key, value_dict in lora_grouped_dict.items():
-            rank = value_dict["lora.down.weight"].shape[0]
-            target_modules = [module for name, module in text_encoder.named_modules() if name == key]
-            if len(target_modules) == 0:
-                logger.warning(f"Could not find module {key} in the model. Skipping.")
-                continue
-
-            target_module = target_modules[0]
-            value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
-            lora_layer = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha)
-            lora_layer.load_state_dict(value_dict)
-            lora_layer.to(device=text_encoder.device, dtype=text_encoder.dtype)
-
-            old_forward = target_module.forward
-
-            def make_new_forward(old_forward, lora_layer):
-                def new_forward(x):
-                    return old_forward(x) + lora_layer(x)
-
-                return new_forward
-
-            # Monkey-patch.
-            target_module.forward = make_new_forward(old_forward, lora_layer)
+        for _, aux_module in text_encoder_aux_modules(text_encoder):
+            if isinstance(aux_module.fc1, PatchedLoraProjection):
+                aux_module.fc1 = aux_module.fc1.regular_linear_layer
+                aux_module.fc2 = aux_module.fc2.regular_linear_layer

     @classmethod
     def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
@@ -1220,6 +1222,13 @@ def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, ra
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

+        for _, aux_module in text_encoder_aux_modules(text_encoder):
+            aux_module.fc1 = PatchedLoraProjection(aux_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype)
+            lora_parameters.extend(aux_module.fc1.lora_linear_layer.parameters())
+
+            aux_module.fc2 = PatchedLoraProjection(aux_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype)
+            lora_parameters.extend(aux_module.fc2.lora_linear_layer.parameters())
+
         return lora_parameters

     @classmethod
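These two hunks are what makes the patching more reliable: instead of monkey-patching each MLP's forward and loading LoRA weights through a separate side path, the commit swaps fc1/fc2 for PatchedLoraProjection modules that keep the wrapped linear as regular_linear_layer. Loading then goes through the normal state-dict path, and un-patching is a plain attribute restore. A minimal toy sketch of that pattern; this is not diffusers' PatchedLoraProjection, just an illustration of the idea:

import torch
import torch.nn as nn

class ToyLoraProjection(nn.Module):
    def __init__(self, regular_linear_layer: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer  # kept so it can be restored later
        self.scale = scale
        # Down/up projection pair; zero-init the up weight so the wrapper starts as a no-op.
        self.lora_linear_layer = nn.Sequential(
            nn.Linear(regular_linear_layer.in_features, rank, bias=False),
            nn.Linear(rank, regular_linear_layer.out_features, bias=False),
        )
        nn.init.zeros_(self.lora_linear_layer[1].weight)

    def forward(self, x):
        return self.regular_linear_layer(x) + self.scale * self.lora_linear_layer(x)

mlp_fc1 = nn.Linear(32, 128)
patched = ToyLoraProjection(mlp_fc1)       # patch: swap the attribute in

x = torch.randn(2, 32)
assert torch.allclose(patched(x), mlp_fc1(x))  # no-op until LoRA weights are loaded

restored = patched.regular_linear_layer    # un-patch: swap the original back
assert restored is mlp_fc1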

tests/models/test_layers_utils.py (+3, -2)
@@ -22,6 +22,7 @@

 from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
 from diffusers.models.embeddings import get_timestep_embedding
+from diffusers.models.lora import LinearWithLoRA
 from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
 from diffusers.models.transformer_2d import Transformer2DModel
 from diffusers.utils import torch_device
@@ -482,7 +483,7 @@ def test_spatial_transformer_default_ff_layers(self):

         assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU
         assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
-        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LinearWithLoRA

         dim = 32
         inner_dim = 128
@@ -506,7 +507,7 @@ def test_spatial_transformer_geglu_approx_ff_layers(self):

         assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU
         assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
-        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LinearWithLoRA

         dim = 32
         inner_dim = 128
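Both assertions change because, on this branch, the FeedForward output projection is built as diffusers.models.lora.LinearWithLoRA rather than a plain nn.Linear. A hedged sketch of the same check outside the test suite; the constructor arguments are illustrative and only meant to mirror the small blocks the tests build:

from diffusers.models.lora import LinearWithLoRA  # assumed to exist on this branch
from diffusers.models.transformer_2d import Transformer2DModel

block = Transformer2DModel(num_attention_heads=1, attention_head_dim=32, in_channels=32, dropout=0.0, cross_attention_dim=None)
ff_net = block.transformer_blocks[0].ff.net
print(type(ff_net[2]) is LinearWithLoRA)  # expected True with this branch's FeedForward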
