# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import os
import re
import warnings
@@ -258,6 +259,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        network_alphas = kwargs.pop("network_alphas", None)
+        is_network_alphas_none = network_alphas is None

        if use_safetensors and not is_safetensors_available():
            raise ValueError(
@@ -349,13 +351,20 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

                # Create another `mapped_network_alphas` dictionary so that we can properly map them.
                if network_alphas is not None:
-                    for k in network_alphas:
+                    network_alphas_ = copy.deepcopy(network_alphas)
+                    for k in network_alphas_:
                        if k.replace(".alpha", "") in key:
-                            mapped_network_alphas.update({attn_processor_key: network_alphas[k]})
+                            mapped_network_alphas.update({attn_processor_key: network_alphas.pop(k)})
+
+            if not is_network_alphas_none:
+                if len(network_alphas) > 0:
+                    raise ValueError(
+                        f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+                    )

            if len(state_dict) > 0:
                raise ValueError(
-                    f"The state_dict has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
+                    f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
                )

            for key, value_dict in lora_grouped_dict.items():
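
The switch from indexing `network_alphas[k]` to popping entries (over a deep-copied iteration source) plus the new emptiness check turns silently ignored alphas into a hard error. Below is a minimal, standalone sketch of that consume-then-verify pattern; the helper name and the keys are invented for illustration and are not the loader's actual variables:

```python
import copy


def map_alphas(network_alphas, layer_keys):
    """Illustrative only: consume alpha entries that match a layer, then verify none are left."""
    mapped = {}
    remaining = copy.deepcopy(network_alphas)
    for layer_key in layer_keys:
        # Snapshot the keys so we can pop from `remaining` while iterating.
        for alpha_key in list(remaining):
            if alpha_key.replace(".alpha", "") in layer_key:
                mapped[layer_key] = remaining.pop(alpha_key)
    if len(remaining) > 0:
        raise ValueError(f"Unmatched `network_alphas` keys: {', '.join(remaining.keys())}")
    return mapped


# One alpha maps cleanly; a leftover (e.g. misspelled) key would raise instead of being dropped.
print(map_alphas({"down_blocks.0.attentions.0.alpha": 4.0}, ["down_blocks.0.attentions.0.processor"]))
```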
@@ -434,14 +443,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                        v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"),
                        out_rank=rank_mapping.get("to_out_lora.down.weight"),
                        out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"),
-                        # rank=rank_mapping.get("to_k_lora.down.weight", None),
-                        # hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
-                        # q_rank=rank_mapping.get("to_q_lora.down.weight", None),
-                        # q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight", None),
-                        # v_rank=rank_mapping.get("to_v_lora.down.weight", None),
-                        # v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight", None),
-                        # out_rank=rank_mapping.get("to_out_lora.down.weight", None),
-                        # out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight", None),
                    )
                else:
                    attn_processors[key] = attn_processor_class(
@@ -496,9 +497,6 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
        # set ff layers
        for target_module, lora_layer in non_attn_lora_layers:
            target_module.set_lora_layer(lora_layer)
-            # It should raise an error if we don't have a set lora here
-            # if hasattr(target_module, "set_lora_layer"):
-            #     target_module.set_lora_layer(lora_layer)

    def save_attn_procs(
        self,
@@ -1251,9 +1249,10 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p
        keys = list(state_dict.keys())
        prefix = cls.text_encoder_name if prefix is None else prefix

+        # Safe prefix to check with.
        if any(cls.text_encoder_name in key for key in keys):
            # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix)]
+            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
            text_encoder_lora_state_dict = {
                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
            }
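
The tightened filter matters when a checkpoint carries LoRA weights for more than one text encoder (e.g. SDXL's `text_encoder` and `text_encoder_2`): `startswith` alone lets `text_encoder_2.*` keys leak into the `text_encoder` bucket, while additionally requiring the first dotted segment to equal the prefix keeps them apart. A quick illustration with invented key names:

```python
keys = [
    "text_encoder.text_model.encoder.layers.0.self_attn.to_q_lora.down.weight",
    "text_encoder_2.text_model.encoder.layers.0.self_attn.to_q_lora.down.weight",
]
prefix = "text_encoder"

loose = [k for k in keys if k.startswith(prefix)]
strict = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]

print(len(loose))   # 2 -- "text_encoder_2..." also starts with "text_encoder"
print(len(strict))  # 1 -- only keys whose first segment is exactly "text_encoder"
```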
@@ -1303,6 +1302,14 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p
                ].shape[1]
            patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())

+            if network_alphas is not None:
+                alpha_keys = [
+                    k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+                ]
+                network_alphas = {
+                    k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+                }
+
            cls._modify_text_encoder(
                text_encoder,
                lora_scale,
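
The new block applies the same first-segment filter to `network_alphas` and then strips the prefix, so the alpha keys end up in the same namespace as the already prefix-stripped text-encoder state dict. Roughly, with invented keys and values:

```python
prefix = "text_encoder"
network_alphas = {
    "text_encoder.text_model.encoder.layers.0.self_attn.to_q_lora.down.weight.alpha": 8.0,
    "unet.down_blocks.0.attentions.0.alpha": 4.0,
}

# Keep only the text-encoder alphas and drop the prefix, mirroring the state-dict keys.
alpha_keys = [k for k in network_alphas if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

print(network_alphas)
# {'text_model.encoder.layers.0.self_attn.to_q_lora.down.weight.alpha': 8.0}
```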
@@ -1364,12 +1371,13 @@ def _modify_text_encoder(

        lora_parameters = []
        network_alphas = {} if network_alphas is None else network_alphas
+        is_network_alphas_populated = len(network_alphas) > 0

        for name, attn_module in text_encoder_attn_modules(text_encoder):
-            query_alpha = network_alphas.get(name + ".k.proj.alpha")
-            key_alpha = network_alphas.get(name + ".q.proj.alpha")
-            value_alpha = network_alphas.get(name + ".v.proj.alpha")
-            proj_alpha = network_alphas.get(name + ".out.proj.alpha")
+            query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
+            key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
+            value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
+            out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)

            attn_module.q_proj = PatchedLoraProjection(
                attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
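
Besides renaming the lookups to the `to_q_lora.down.weight.alpha` form that the converted checkpoints actually use (the old `.k.proj`/`.q.proj` lookups were also visibly swapped between query and key), the code now pops the entries so leftovers can be reported at the end. As background on what the value does: `network_alpha` is the kohya-style scaling knob, and the low-rank update is effectively multiplied by roughly `alpha / rank`. A hedged, standalone sketch of that factor (not the library's internal code path):

```python
def lora_scale_factor(network_alpha, rank):
    """Effective multiplier on the low-rank update, kohya-style; no stored alpha means factor 1.0."""
    return 1.0 if network_alpha is None else network_alpha / rank


print(lora_scale_factor(None, 4))  # 1.0
print(lora_scale_factor(8.0, 4))   # 2.0 -- alpha larger than rank boosts the update
print(lora_scale_factor(2.0, 4))   # 0.5 -- alpha smaller than rank damps it
```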
@@ -1387,14 +1395,14 @@ def _modify_text_encoder(
            lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())

            attn_module.out_proj = PatchedLoraProjection(
-                attn_module.out_proj, lora_scale, network_alpha=proj_alpha, rank=rank, dtype=dtype
+                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=rank, dtype=dtype
            )
            lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

        if patch_mlp:
            for name, mlp_module in text_encoder_mlp_modules(text_encoder):
-                fc1_alpha = network_alphas.get(name + ".fc1.alpha")
-                fc2_alpha = network_alphas.get(name + ".fc2.alpha")
+                fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha")
+                fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha")

                mlp_module.fc1 = PatchedLoraProjection(
                    mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
@@ -1406,6 +1414,11 @@ def _modify_text_encoder(
                )
                lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())

+        if is_network_alphas_populated and len(network_alphas) > 0:
+            raise ValueError(
+                f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+            )
+
        return lora_parameters

    @classmethod
@@ -1519,10 +1532,6 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
            lora_name_up = lora_name + ".lora_up.weight"
            lora_name_alpha = lora_name + ".alpha"

-            # if lora_name_alpha in state_dict:
-            #     alpha = state_dict.pop(lora_name_alpha).item()
-            #     network_alphas.update({lora_name_alpha: alpha})
-
            if lora_name.startswith("lora_unet_"):
                diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
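
For orientation, the context lines above come from the kohya-to-diffusers key conversion: UNet keys start with `lora_unet_`, use underscores where diffusers uses dots, and each LoRA pair carries a sibling `.alpha` scalar (whose handling is what the deleted comment block used to sketch). A toy rendition of just the renaming step shown here, with an invented key; the real converter performs further fix-ups afterwards:

```python
key = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"

lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight"   # the matching up-projection weight
lora_name_alpha = lora_name + ".alpha"         # the per-module alpha scalar, if stored

diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
print(diffusers_name)
# down.blocks.0.attentions.0.transformer.blocks.0.attn1.to.q.lora.down.weight
# (later steps restore compound names such as "down_blocks" and "to_q")
```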