Commit c769160

Add LoRA-FA for LoRA+
1 parent f99fe28 commit c769160

1 file changed: +38 −20 lines

networks/lora_fa.py

+38 −20
@@ -1033,22 +1033,43 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight

     # It might be good to allow separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
         self.requires_grad_(True)
         all_params = []

-        def enumerate_params(loras: List[LoRAModule]):
-            params = []
+        def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                # params.extend(lora.parameters())
-                params.extend(lora.get_trainable_params())
+                for name, param in lora.get_trainable_named_params():
+                    if lora_plus_ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            # assigned_param_groups = ""
+            # for group in param_groups:
+            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
+            # logger.info(assigned_param_groups)
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * lora_plus_ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+
+                params.append(param_data)
+
             return params

         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            all_params.extend(params)

         if self.unet_loras:
             if self.block_lr:
@@ -1062,21 +1083,15 @@ def enumerate_params(loras: List[LoRAModule]):

                 # set parameters per block
                 for idx, block_loras in block_idx_to_lora.items():
-                    param_data = {"params": enumerate_params(block_loras)}
-
                     if unet_lr is not None:
-                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                        params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
                     elif default_lr is not None:
-                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                    if ("lr" in param_data) and (param_data["lr"] == 0):
-                        continue
-                    all_params.append(param_data)
+                        params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+                    all_params.extend(params)

             else:
-                param_data = {"params": enumerate_params(self.unet_loras)}
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr
-                all_params.append(param_data)
+                params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+                all_params.extend(params)

         return all_params

@@ -1093,6 +1108,9 @@ def on_epoch_start(self, text_encoder, unet):
     def get_trainable_params(self):
         return self.parameters()

+    def get_trainable_named_params(self):
+        return self.named_parameters()
+
     def save_weights(self, file, dtype, metadata):
         if metadata is not None and len(metadata) == 0:
             metadata = None
