 import torch
 import re
 from library.utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)
 
 RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -504,6 +506,15 @@ def create_network(
     if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
         network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
 
+    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
     return network
 
 
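Note on the hunk above: in sd-scripts the extra keyword arguments reaching create_network typically come from the trainer's --network_args option as "key=value" strings, which is why the new loraplus_* values are cast with float() before use. A hypothetical call sketch (the model objects and values are placeholders, not part of this diff):

    # loraplus_* values usually arrive as strings parsed from --network_args, e.g.
    #   --network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"
    network = create_network(
        1.0,            # multiplier
        16,             # network_dim
        8,              # network_alpha
        vae,            # placeholder model objects, assumed to be loaded elsewhere
        text_encoder,
        unet,
        loraplus_lr_ratio="16",              # global LoRA+ ratio
        loraplus_text_encoder_lr_ratio="4",  # per-model override for the text encoder
    )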
@@ -529,7 +540,9 @@ def parse_floats(s):
             len(block_dims) == num_total_blocks
         ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
     else:
-        logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
+        logger.warning(
+            f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
+        )
         block_dims = [network_dim] * num_total_blocks
 
     if block_alphas is not None:
@@ -803,21 +816,31 @@ def __init__(
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
 
+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
         elif block_dims is not None:
             logger.info(f"create LoRA network from block_dims")
-            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
             logger.info(f"block_dims: {block_dims}")
             logger.info(f"block_alphas: {block_alphas}")
             if conv_block_dims is not None:
                 logger.info(f"conv_block_dims: {conv_block_dims}")
                 logger.info(f"conv_block_alphas: {conv_block_alphas}")
         else:
             logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
-            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
             if self.conv_lora_dim is not None:
-                logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+                logger.info(
+                    f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
+                )
 
         # create module instances
         def create_modules(
@@ -939,6 +962,11 @@ def create_modules(
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)
 
+    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+        self.loraplus_lr_ratio = loraplus_lr_ratio
+        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
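The new set_loraplus_lr_ratio setter only stores the three ratios; the per-model values take precedence over the global one later in prepare_optimizer_params via Python's or operator. A small illustration of that resolution (values are arbitrary):

    network.set_loraplus_lr_ratio(4.0, None, 8.0)  # global, unet, text encoder

    # mirrors the fallback used in prepare_optimizer_params in the later hunks
    unet_ratio = network.loraplus_unet_lr_ratio or network.loraplus_lr_ratio        # 4.0 (falls back to global)
    te_ratio = network.loraplus_text_encoder_lr_ratio or network.loraplus_lr_ratio  # 8.0 (override wins)

Because the fallback uses or, an explicit ratio of 0 would also fall through to the global value.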
@@ -1033,15 +1061,7 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight
 
     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(
-        self,
-        text_encoder_lr,
-        unet_lr,
-        default_lr,
-        text_encoder_loraplus_ratio=None,
-        unet_loraplus_ratio=None,
-        loraplus_ratio=None
-    ):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []
 
@@ -1078,7 +1098,7 @@ def assemble_params(loras, lr, ratio):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_loraplus_ratio or loraplus_ratio
+                self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
             )
             all_params.extend(params)
 
@@ -1097,15 +1117,15 @@ def assemble_params(loras, lr, ratio):
                     params = assemble_params(
                         block_loras,
                         (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                        unet_loraplus_ratio or loraplus_ratio
+                        self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
                     )
                     all_params.extend(params)
 
             else:
                 params = assemble_params(
                     self.unet_loras,
                     unet_lr if unet_lr is not None else default_lr,
-                    unet_loraplus_ratio or loraplus_ratio
+                    self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
                 )
                 all_params.extend(params)
 
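The refactored prepare_optimizer_params no longer takes the ratios as arguments and instead reads the attributes stored on the network. assemble_params itself is not shown in this diff, so the sketch below is only an assumption of what a LoRA+-style helper would do: split each module's parameters so that the lora_up (B) matrices train at lr * ratio while the lora_down (A) matrices keep the base lr.

    def assemble_params(loras, lr, ratio):
        # Assumed LoRA+ grouping: "plus" collects the lora_up / B weights,
        # everything else stays in the base group.
        param_groups = {"lora": {}, "plus": {}}
        for lora in loras:
            for name, param in lora.named_parameters():
                key = "plus" if "lora_up" in name else "lora"
                param_groups[key][f"{lora.lora_name}.{name}"] = param

        params = []
        for key, group in param_groups.items():
            if not group:
                continue
            param_data = {"params": list(group.values())}
            if lr is not None:
                # the LoRA+ ratio scales only the lora_up group
                param_data["lr"] = lr * ratio if (key == "plus" and ratio is not None) else lr
            params.append(param_data)
        return params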