@@ -109,6 +109,20 @@ def text_encoder_attn_modules(text_encoder):
     return attn_modules


+def text_encoder_aux_modules(text_encoder):
+    aux_modules = []
+
+    if isinstance(text_encoder, CLIPTextModel):
+        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+            mlp_mod = layer.mlp
+            name = f"text_model.encoder.layers.{i}.mlp"
+            aux_modules.append((name, mlp_mod))
+    else:
+        raise ValueError(f"do not know how to get aux modules for: {text_encoder.__class__.__name__}")
+
+    return aux_modules
+
+
 def text_encoder_lora_state_dict(text_encoder):
     state_dict = {}

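For orientation, a minimal usage sketch (not part of the diff) of the new helper, assuming a standard transformers CLIPTextModel checkpoint; the model name and printed feature sizes are illustrative only:

from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
for name, mlp in text_encoder_aux_modules(text_encoder):
    # each entry is ("text_model.encoder.layers.<i>.mlp", that layer's MLP block)
    print(name, mlp.fc1.in_features, mlp.fc1.out_features)  # e.g. 512 -> 2048 for the ViT-B/32 text tower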
@@ -1079,6 +1093,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
         text_encoder_lora_state_dict = {
             k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
         }
+        text_encoder_lora_state_dict = {**text_encoder_lora_state_dict, **state_dict_aux}
         if len(text_encoder_lora_state_dict) > 0:
             logger.info(f"Loading {cls.text_encoder_name}.")

@@ -1119,13 +1134,26 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
                     f"{name}.out_proj.lora_linear_layer.down.weight"
                 ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")

+            for name, _ in text_encoder_aux_modules(text_encoder):
+                text_encoder_lora_state_dict[
+                    f"{name}.fc1.lora_linear_layer.up.weight"
+                ] = text_encoder_lora_state_dict.pop(f"{name}.fc1.lora.up.weight")
+                text_encoder_lora_state_dict[
+                    f"{name}.fc2.lora_linear_layer.up.weight"
+                ] = text_encoder_lora_state_dict.pop(f"{name}.fc2.lora.up.weight")
+
+                text_encoder_lora_state_dict[
+                    f"{name}.fc1.lora_linear_layer.down.weight"
+                ] = text_encoder_lora_state_dict.pop(f"{name}.fc1.lora.down.weight")
+                text_encoder_lora_state_dict[
+                    f"{name}.fc2.lora_linear_layer.down.weight"
+                ] = text_encoder_lora_state_dict.pop(f"{name}.fc2.lora.down.weight")
+
             rank = text_encoder_lora_state_dict[
                 "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
             ].shape[1]

             cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
-            if state_dict_aux:
-                cls._load_lora_aux_for_text_encoder(text_encoder, state_dict_aux, network_alpha=network_alpha)

             # set correct dtype & device
             text_encoder_lora_state_dict = {
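To make the renaming above concrete, a self-contained toy example (hypothetical key and tensor shape) showing how one fc1 entry moves from the old monkey-patch naming to the lora_linear_layer naming:

import torch

sd = {"text_model.encoder.layers.0.mlp.fc1.lora.up.weight": torch.zeros(2048, 4)}
name = "text_model.encoder.layers.0.mlp"
# pop the old-format key and re-insert it under the new name, keeping the tensor unchanged
sd[f"{name}.fc1.lora_linear_layer.up.weight"] = sd.pop(f"{name}.fc1.lora.up.weight")
print(list(sd))  # ['text_model.encoder.layers.0.mlp.fc1.lora_linear_layer.up.weight']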
@@ -1157,36 +1185,10 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
                 attn_module.v_proj = attn_module.v_proj.regular_linear_layer
                 attn_module.out_proj = attn_module.out_proj.regular_linear_layer

-    @classmethod
-    def _load_lora_aux_for_text_encoder(cls, text_encoder, state_dict, network_alpha=None):
-        lora_grouped_dict = defaultdict(dict)
-        for key, value in state_dict.items():
-            attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
-            lora_grouped_dict[attn_processor_key][sub_key] = value
-
-        for key, value_dict in lora_grouped_dict.items():
-            rank = value_dict["lora.down.weight"].shape[0]
-            target_modules = [module for name, module in text_encoder.named_modules() if name == key]
-            if len(target_modules) == 0:
-                logger.warning(f"Could not find module {key} in the model. Skipping.")
-                continue
-
-            target_module = target_modules[0]
-            value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
-            lora_layer = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha)
-            lora_layer.load_state_dict(value_dict)
-            lora_layer.to(device=text_encoder.device, dtype=text_encoder.dtype)
-
-            old_forward = target_module.forward
-
-            def make_new_forward(old_forward, lora_layer):
-                def new_forward(x):
-                    return old_forward(x) + lora_layer(x)
-
-                return new_forward
-
-            # Monkey-patch.
-            target_module.forward = make_new_forward(old_forward, lora_layer)
+        for _, aux_module in text_encoder_aux_modules(text_encoder):
+            if isinstance(aux_module.fc1, PatchedLoraProjection):
+                aux_module.fc1 = aux_module.fc1.regular_linear_layer
+                aux_module.fc2 = aux_module.fc2.regular_linear_layer

     @classmethod
     def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
@@ -1220,6 +1222,13 @@ def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, ra
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

+        for _, aux_module in text_encoder_aux_modules(text_encoder):
+            aux_module.fc1 = PatchedLoraProjection(aux_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype)
+            lora_parameters.extend(aux_module.fc1.lora_linear_layer.parameters())
+
+            aux_module.fc2 = PatchedLoraProjection(aux_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype)
+            lora_parameters.extend(aux_module.fc2.lora_linear_layer.parameters())
+
         return lora_parameters

     @classmethod
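As a rough, self-contained stand-in (not the diffusers implementation), the patched fc1/fc2 can be pictured as following the usual LoRA formulation y = Wx + scale * up(down(x)); with the up projection zero-initialized, the wrapped layer reproduces the base layer until the LoRA weights are trained or loaded:

import torch
import torch.nn as nn

class ToyPatchedProjection(nn.Module):
    def __init__(self, regular_linear_layer, lora_scale=1.0, rank=4):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer
        self.lora_scale = lora_scale
        self.down = nn.Linear(regular_linear_layer.in_features, rank, bias=False)
        self.up = nn.Linear(rank, regular_linear_layer.out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # standard LoRA init: start as a no-op

    def forward(self, x):
        # base projection plus scaled low-rank update
        return self.regular_linear_layer(x) + self.lora_scale * self.up(self.down(x))

fc1 = nn.Linear(512, 2048)
patched = ToyPatchedProjection(fc1, lora_scale=1.0, rank=4)
x = torch.randn(1, 77, 512)
assert torch.allclose(patched(x), fc1(x))  # identical output before any LoRA training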