@@ -1033,22 +1033,43 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight

     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
         self.requires_grad_(True)
         all_params = []

-        def enumerate_params(loras: List[LoRAModule]):
-            params = []
+        def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                # params.extend(lora.parameters())
-                params.extend(lora.get_trainable_params())
+                for name, param in lora.get_trainable_named_params():
+                    if lora_plus_ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            # assigned_param_groups = ""
+            # for group in param_groups:
+            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
+            # logger.info(assigned_param_groups)
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * lora_plus_ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+
+                params.append(param_data)
+
             return params

         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            all_params.extend(params)

         if self.unet_loras:
             if self.block_lr:
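
For orientation while reading the hunk above, here is a minimal, self-contained sketch (not part of the patch) of the grouping rule `assemble_params` applies. When a LoRA+ ratio is given, parameters whose names contain `lora_up` go into a separate "plus" optimizer group whose learning rate is the base rate times the ratio; everything else keeps the base rate, and any group whose learning rate works out to 0 is dropped. The module layout and ratio value below are invented for illustration.

```python
import torch

base_lr = 1e-4
lora_plus_ratio = 16  # hypothetical value; None disables the split entirely

# stand-in for lora.get_trainable_named_params() on one hypothetical module
named_params = {
    "lora_down.weight": torch.nn.Parameter(torch.zeros(4, 320)),
    "lora_up.weight": torch.nn.Parameter(torch.zeros(320, 4)),
}

param_groups = {"lora": {}, "plus": {}}
for name, param in named_params.items():
    # same substring test as the patch: "lora_up" params go to the boosted group
    group = "plus" if lora_plus_ratio is not None and "lora_up" in name else "lora"
    param_groups[group][name] = param

optimizer_groups = []
for key, group in param_groups.items():
    lr = base_lr * lora_plus_ratio if key == "plus" else base_lr
    if not group or lr == 0:
        continue  # mirrors the `continue` that skips zero-lr groups
    optimizer_groups.append({"params": list(group.values()), "lr": lr})

optimizer = torch.optim.AdamW(optimizer_groups)
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 0.0016]
```
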
@@ -1062,21 +1083,15 @@ def enumerate_params(loras: List[LoRAModule]):

                 # set the parameters for each block
                 for idx, block_loras in block_idx_to_lora.items():
-                    param_data = {"params": enumerate_params(block_loras)}
-
                     if unet_lr is not None:
-                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                        params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
                     elif default_lr is not None:
-                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                    if ("lr" in param_data) and (param_data["lr"] == 0):
-                        continue
-                    all_params.append(param_data)
+                        params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+                    all_params.extend(params)

             else:
-                param_data = {"params": enumerate_params(self.unet_loras)}
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr
-                all_params.append(param_data)
+                params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+                all_params.extend(params)

         return all_params

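
This second hunk wires the same helper into the per-block learning-rate path. A small sketch (values invented) of how the block weight from `get_lr_weight` and the LoRA+ ratio combine: the base rate for a block is `unet_lr * weight`, the "plus" group gets that base times the ratio, and groups whose rate comes out to 0 are skipped.

```python
# Hypothetical block weights standing in for self.get_lr_weight(block_loras[0])
unet_lr = 1e-4
unet_lora_plus_ratio = 16
block_lr_weights = {3: 0.5, 7: 1.0, 11: 0.0}

for idx, weight in block_lr_weights.items():
    base = unet_lr * weight                        # lr passed into assemble_params
    plus = base * unet_lora_plus_ratio             # lr of the "plus" group
    kept = [lr for lr in (base, plus) if lr != 0]  # zero-lr groups are dropped
    print(f"block {idx}: lrs={kept}")
```
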
@@ -1093,6 +1108,9 @@ def on_epoch_start(self, text_encoder, unet):
     def get_trainable_params(self):
         return self.parameters()

+    def get_trainable_named_params(self):
+        return self.named_parameters()
+
     def save_weights(self, file, dtype, metadata):
         if metadata is not None and len(metadata) == 0:
             metadata = None
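
The last hunk adds `get_trainable_named_params`, a thin wrapper over `named_parameters()`. The sketch below uses an assumed, simplified module layout (the repo's real `LoRAModule` is more involved) to show why named parameters are enough for the grouping: the names carry the `lora_down`/`lora_up` prefixes that `assemble_params` filters on.

```python
import torch.nn as nn

class TinyLoRA(nn.Module):
    """Assumed, simplified stand-in for LoRAModule (illustration only)."""

    def __init__(self, in_dim: int = 320, rank: int = 4, out_dim: int = 320):
        super().__init__()
        self.lora_down = nn.Linear(in_dim, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_dim, bias=False)

module = TinyLoRA()
for name, _ in module.named_parameters():
    print(name)  # "lora_down.weight", "lora_up.weight"
```
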