
Commit f99fe28

Add LoRA+ support
1 parent f931705 commit f99fe28

File tree

library/train_util.py
networks/dylora.py
networks/lora.py
train_network.py

4 files changed: +71 -32 lines changed

library/train_util.py (+2)

@@ -2789,6 +2789,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
+    parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
+    parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
 
 
 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):

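For orientation, here is a minimal sketch of what the two new flags do to the effective learning rates, assuming the LoRA+ convention that the lora_up ("B") matrices train at ratio times the base rate. The flag names come from this commit; the tiny parser, the example values, and the printout are illustrative assumptions only.

import argparse

# Stand-in parser with only the relevant flags (the real add_optimizer_arguments defines many more).
parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--unet_lr", type=float, default=None)
parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float)
parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float)

# Illustrative invocation; the LoRA+ paper suggests ratios on the order of 16.
args = parser.parse_args(["--unet_lr", "1e-4", "--loraplus_unet_lr_ratio", "16"])

unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
if args.loraplus_unet_lr_ratio is not None:
    # lora_down ("A") weights keep the base rate; lora_up ("B") weights are scaled by the ratio.
    print(f"lora_down lr: {unet_lr:.1e}")                                # 1.0e-04
    print(f"lora_up   lr: {unet_lr * args.loraplus_unet_lr_ratio:.1e}")  # 1.6e-03
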
networks/dylora.py (+33 -12)

@@ -406,27 +406,48 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
         logger.info(f"weights are merged")
     """
 
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    # It might be good to allow setting separate learning rates for the two Text Encoders
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
         self.requires_grad_(True)
         all_params = []
 
-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, lora_plus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if lora_plus_ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            # assigned_param_groups = ""
+            # for group in param_groups:
+            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
+            # logger.info(assigned_param_groups)
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * lora_plus_ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+
+                params.append(param_data)
+
             return params
 
         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            all_params.extend(params)
 
         if self.unet_loras:
-            param_data = {"params": enumerate_params(self.unet_loras)}
-            if unet_lr is not None:
-                param_data["lr"] = unet_lr
-            all_params.append(param_data)
+            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            all_params.extend(params)
 
         return all_params

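To make the grouping concrete, here is a small self-contained sketch (assuming PyTorch) of the same idea: parameters whose names contain "lora_up" form a "plus" group trained at lr * ratio, everything else keeps the base lr, and the resulting group list is handed straight to the optimizer. The ToyLoRAModule class and the numbers are illustrative assumptions, not code from this commit.

import torch
import torch.nn as nn

class ToyLoRAModule(nn.Module):
    """Illustrative stand-in for a LoRA module holding lora_down/lora_up weights."""
    def __init__(self, lora_name, in_dim=8, rank=4, out_dim=8):
        super().__init__()
        self.lora_name = lora_name
        self.lora_down = nn.Linear(in_dim, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_dim, bias=False)

def assemble_params(loras, lr, lora_plus_ratio):
    # Split parameters: "lora_up" weights go to the LoRA+ "plus" group, the rest stay in "lora".
    param_groups = {"lora": {}, "plus": {}}
    for lora in loras:
        for name, param in lora.named_parameters():
            group = "plus" if (lora_plus_ratio is not None and "lora_up" in name) else "lora"
            param_groups[group][f"{lora.lora_name}.{name}"] = param

    params = []
    for key, group in param_groups.items():
        if not group:
            continue
        param_data = {"params": list(group.values())}
        if lr is not None:
            param_data["lr"] = lr * lora_plus_ratio if key == "plus" else lr
        if param_data.get("lr") == 0:
            continue  # groups whose effective lr is zero are dropped entirely
        params.append(param_data)
    return params

loras = [ToyLoRAModule("lora_unet_block0"), ToyLoRAModule("lora_unet_block1")]
all_params = assemble_params(loras, lr=1e-4, lora_plus_ratio=16)
optimizer = torch.optim.AdamW(all_params, lr=1e-4)
for g in optimizer.param_groups:
    print(len(g["params"]), g["lr"])  # lora_down group at 1e-4, then lora_up group at 1.6e-3
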
networks/lora.py (+35 -19)

@@ -1035,21 +1035,43 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight
 
     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
         self.requires_grad_(True)
         all_params = []
 
-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, lora_plus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if lora_plus_ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            # assigned_param_groups = ""
+            # for group in param_groups:
+            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
+            # logger.info(assigned_param_groups)
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * lora_plus_ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+
+                params.append(param_data)
+
             return params
 
         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            all_params.extend(params)
 
         if self.unet_loras:
             if self.block_lr:
@@ -1063,21 +1085,15 @@ def enumerate_params(loras):
 
                 # set the parameters for each block
                 for idx, block_loras in block_idx_to_lora.items():
-                    param_data = {"params": enumerate_params(block_loras)}
-
                     if unet_lr is not None:
-                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                        params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
                     elif default_lr is not None:
-                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                    if ("lr" in param_data) and (param_data["lr"] == 0):
-                        continue
-                    all_params.append(param_data)
+                        params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+                    all_params.extend(params)
 
             else:
-                param_data = {"params": enumerate_params(self.unet_loras)}
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr
-                all_params.append(param_data)
+                params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+                all_params.extend(params)
 
         return all_params

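When block-wise learning rates are active, the base lr handed to assemble_params is already scaled by the block's weight, so the LoRA+ ratio multiplies that per-block value for the lora_up group, and blocks whose scaled lr is zero are skipped. A rough arithmetic sketch, with made-up block weights standing in for get_lr_weight:

# Hypothetical per-block lr weights (stand-ins for get_lr_weight); values are made up.
block_lr_weights = {0: 1.0, 1: 0.5, 2: 0.0}
unet_lr = 1e-4
loraplus_ratio = 16

for idx, weight in block_lr_weights.items():
    base = unet_lr * weight           # lr passed to assemble_params for this block
    plus = base * loraplus_ratio      # lr of the block's lora_up ("plus") group
    if base == 0:
        print(f"block {idx}: dropped (effective lr is 0)")
    else:
        print(f"block {idx}: lora_down lr={base:.2e}, lora_up lr={plus:.2e}")
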
train_network.py (+1 -1)

@@ -339,7 +339,7 @@ def train(self, args):
 
         # keep backward compatibility
         try:
-            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
+            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_unet_lr_ratio, args.loraplus_text_encoder_lr_ratio)
         except TypeError:
             accelerator.print(
                 "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"

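The try/except exists so that network modules whose prepare_optimizer_params still has the old three-argument signature keep working: passing the extra LoRA+ ratios to such a module raises TypeError and the trainer retries with the older call. A minimal sketch of that pattern; OldNetwork, NewNetwork, and get_trainable_params are illustrative names, not code from the repository.

class OldNetwork:
    # Pre-LoRA+ signature: no ratio arguments.
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        return [{"params": [], "lr": unet_lr}]

class NewNetwork:
    # Extended signature: accepts the two LoRA+ ratios.
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr,
                                 unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
        return [{"params": [], "lr": unet_lr}]

def get_trainable_params(network, text_encoder_lr, unet_lr, learning_rate, unet_ratio, te_ratio):
    try:
        # Prefer the extended call; implementations with the old signature reject the extra arguments.
        return network.prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate, unet_ratio, te_ratio)
    except TypeError:
        print("Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate)")
        return network.prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate)

get_trainable_params(NewNetwork(), 5e-5, 1e-4, 1e-4, 16, 4)  # uses the LoRA+ ratios
get_trainable_params(OldNetwork(), 5e-5, 1e-4, 1e-4, 16, 4)  # falls back and prints the notice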