@@ -1035,21 +1035,43 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
1035
1035
return lr_weight
1036
1036
1037
1037
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
1038
- def prepare_optimizer_params (self , text_encoder_lr , unet_lr , default_lr ):
1038
+ def prepare_optimizer_params (self , text_encoder_lr , unet_lr , default_lr , unet_lora_plus_ratio = None , text_encoder_lora_plus_ratio = None ):
1039
1039
self .requires_grad_ (True )
1040
1040
all_params = []
1041
1041
1042
def assemble_params(loras, lr, lora_plus_ratio):
    """Build optimizer param-group dicts for a list of LoRA modules.

    Implements the LoRA+ scheme: when ``lora_plus_ratio`` is given, the
    ``lora_up`` (B-matrix) parameters are split into their own param group
    whose learning rate is ``lr * lora_plus_ratio``; all other parameters
    go into the base group at ``lr``.

    Args:
        loras: LoRA modules exposing ``.lora_name`` and ``.named_parameters()``.
        lr: base learning rate for these modules, or None to defer to the
            optimizer's default lr.
        lora_plus_ratio: multiplier applied to the ``lora_up`` group's lr,
            or None to disable LoRA+ (single group).

    Returns:
        A list of optimizer param-group dicts (``{"params": ..., "lr": ...}``).
        Groups whose lr resolves to 0 are dropped, effectively freezing
        those parameters.
    """
    param_groups = {"lora": {}, "plus": {}}
    for lora in loras:
        for name, param in lora.named_parameters():
            # LoRA+ boosts only the up (B) matrices; everything else is base.
            if lora_plus_ratio is not None and "lora_up" in name:
                param_groups["plus"][f"{lora.lora_name}.{name}"] = param
            else:
                param_groups["lora"][f"{lora.lora_name}.{name}"] = param

    params = []
    for key, group in param_groups.items():
        # Skip empty groups. Besides being useless to the optimizer, an
        # empty "plus" group would otherwise fall through to
        # `lr * lora_plus_ratio` with lora_plus_ratio == None (TypeError)
        # whenever LoRA+ is disabled but lr is set.
        if not group:
            continue

        param_data = {"params": group.values()}
        if lr is not None:
            if key == "plus":
                param_data["lr"] = lr * lora_plus_ratio
            else:
                param_data["lr"] = lr

        # lr == 0 means "do not train these parameters" — drop the group.
        if ("lr" in param_data) and (param_data["lr"] == 0):
            continue

        params.append(param_data)

    return params
1047
1071
1048
1072
if self .text_encoder_loras :
1049
- param_data = {"params" : enumerate_params (self .text_encoder_loras )}
1050
- if text_encoder_lr is not None :
1051
- param_data ["lr" ] = text_encoder_lr
1052
- all_params .append (param_data )
1073
+ params = assemble_params (self .text_encoder_loras , text_encoder_lr , text_encoder_lora_plus_ratio )
1074
+ all_params .extend (params )
1053
1075
1054
1076
if self .unet_loras :
1055
1077
if self .block_lr :
@@ -1063,21 +1085,15 @@ def enumerate_params(loras):
1063
1085
1064
1086
# blockごとにパラメータを設定する
1065
1087
for idx , block_loras in block_idx_to_lora .items ():
1066
- param_data = {"params" : enumerate_params (block_loras )}
1067
-
1068
1088
if unet_lr is not None :
1069
- param_data [ "lr" ] = unet_lr * self .get_lr_weight (block_loras [0 ])
1089
+ params = assemble_params ( block_loras , unet_lr * self .get_lr_weight (block_loras [0 ]), unet_lora_plus_ratio )
1070
1090
elif default_lr is not None :
1071
- param_data ["lr" ] = default_lr * self .get_lr_weight (block_loras [0 ])
1072
- if ("lr" in param_data ) and (param_data ["lr" ] == 0 ):
1073
- continue
1074
- all_params .append (param_data )
1091
+ params = assemble_params (block_loras , default_lr * self .get_lr_weight (block_loras [0 ]), unet_lora_plus_ratio )
1092
+ all_params .extend (params )
1075
1093
1076
1094
else :
1077
- param_data = {"params" : enumerate_params (self .unet_loras )}
1078
- if unet_lr is not None :
1079
- param_data ["lr" ] = unet_lr
1080
- all_params .append (param_data )
1095
+ params = assemble_params (self .unet_loras , unet_lr , unet_lora_plus_ratio )
1096
+ all_params .extend (params )
1081
1097
1082
1098
return all_params
1083
1099
0 commit comments