
Commit 969f82a

move loraplus args from args to network_args, simplify log lr desc
1 parent 834445a commit 969f82a

3 files changed, +84 -91 lines


library/train_util.py

+0 -3
@@ -2920,9 +2920,6 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
-    parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
-    parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
-    parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")


 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
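
With these flags gone from the argument parser, the LoRA+ ratios now travel through --network_args and reach create_network() in networks/lora.py (next file) as keyword arguments. A minimal sketch of that hand-off, assuming the usual sd-scripts behaviour that each "key=value" entry of --network_args arrives as a string kwarg; the value 16 and the variable names are illustrative only:

# Illustrative only: something like --network_args "loraplus_lr_ratio=16" ends up as a string
# kwarg in create_network(**kwargs), which is why the lora.py hunk below re-parses with float().
network_args = ["loraplus_lr_ratio=16", "loraplus_unet_lr_ratio=16"]   # hypothetical user input
kwargs = dict(arg.split("=", 1) for arg in network_args)               # {"loraplus_lr_ratio": "16", ...}

loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None   # 16.0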

networks/lora.py

+40 -18
@@ -490,6 +490,14 @@ def create_network(
         varbose=True,
     )

+    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+    network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
     if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
         network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)

@@ -1033,18 +1041,27 @@ def get_lr_weight(self, lora: LoRAModule) -> float:

         return lr_weight

+    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+        self.loraplus_lr_ratio = loraplus_lr_ratio
+        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
     # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
-    def prepare_optimizer_params(
-        self,
-        text_encoder_lr,
-        unet_lr,
-        default_lr,
-        text_encoder_loraplus_ratio=None,
-        unet_loraplus_ratio=None,
-        loraplus_ratio=None
-    ):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+        # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
+        # if (
+        #     self.loraplus_lr_ratio is not None
+        #     or self.loraplus_text_encoder_lr_ratio is not None
+        #     or self.loraplus_unet_lr_ratio is not None
+        # ):
+        #     assert (
+        #         optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
+        #     ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
+
         self.requires_grad_(True)
+
         all_params = []
+        lr_descriptions = []

         def assemble_params(loras, lr, ratio):
             param_groups = {"lora": {}, "plus": {}}
@@ -1056,6 +1073,7 @@ def assemble_params(loras, lr, ratio):
                         param_groups["lora"][f"{lora.lora_name}.{name}"] = param

             params = []
+            descriptions = []
             for key in param_groups.keys():
                 param_data = {"params": param_groups[key].values()}

@@ -1069,20 +1087,22 @@ def assemble_params(loras, lr, ratio):
                         param_data["lr"] = lr

                 if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
-                    print("NO LR skipping!")
+                    logger.info("NO LR skipping!")
                     continue

                 params.append(param_data)
+                descriptions.append("plus" if key == "plus" else "")

-            return params
+            return params, descriptions

         if self.text_encoder_loras:
-            params = assemble_params(
+            params, descriptions = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_loraplus_ratio or loraplus_ratio
+                self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
             )
             all_params.extend(params)
+            lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])

         if self.unet_loras:
             if self.block_lr:
@@ -1096,22 +1116,24 @@ def assemble_params(loras, lr, ratio):

                 # blockごとにパラメータを設定する
                 for idx, block_loras in block_idx_to_lora.items():
-                    params = assemble_params(
+                    params, descriptions = assemble_params(
                         block_loras,
                         (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                        unet_loraplus_ratio or loraplus_ratio
+                        self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
                     )
                     all_params.extend(params)
+                    lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])

             else:
-                params = assemble_params(
+                params, descriptions = assemble_params(
                     self.unet_loras,
                     unet_lr if unet_lr is not None else default_lr,
-                    unet_loraplus_ratio or loraplus_ratio
+                    self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
                 )
                 all_params.extend(params)
+                lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])

-        return all_params
+        return all_params, lr_descriptions

     def enable_gradient_checkpointing(self):
         # not supported
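
To see what the new return value contains, here is a minimal sketch of the grouping assemble_params performs, under the assumption (taken from the unchanged "lora"/"plus" split above and from the LoRA+ recipe) that the "plus" group holds the lora_up parameters and is trained at lr * ratio; the function and variable names are illustrative only:

# Sketch under stated assumptions, not the module itself.
def assemble_params_sketch(named_params, lr, ratio):
    named_params = list(named_params)          # e.g. (name, param) pairs from named_parameters()
    groups = {
        "lora": [p for n, p in named_params if "lora_up" not in n],
        "plus": [p for n, p in named_params if "lora_up" in n],
    }
    params, descriptions = [], []
    for key, group_params in groups.items():
        if lr is None or lr == 0:
            continue                           # mirrors the "NO LR skipping!" branch
        scaled_lr = lr * ratio if (key == "plus" and ratio is not None) else lr
        params.append({"params": group_params, "lr": scaled_lr})
        descriptions.append("plus" if key == "plus" else "")
    return params, descriptions

prepare_optimizer_params then prefixes each description with "textencoder", "unet" or f"unet_block{idx}", which is where entries such as "unet plus" in lr_descriptions come from.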

train_network.py

+44 -70
@@ -53,7 +53,15 @@ def __init__(self):

     # TODO 他のスクリプトと共通化する
     def generate_step_logs(
-        self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
+        self,
+        args: argparse.Namespace,
+        current_loss,
+        avr_loss,
+        lr_scheduler,
+        lr_descriptions,
+        keys_scaled=None,
+        mean_norm=None,
+        maximum_norm=None,
     ):
         logs = {"loss/current": current_loss, "loss/average": avr_loss}

@@ -63,68 +71,25 @@ def generate_step_logs(
             logs["max_norm/max_key_norm"] = maximum_norm

         lrs = lr_scheduler.get_last_lr()
-
-        if len(lrs) > 4:
-            idx = 0
-            if not args.network_train_unet_only:
-                logs["lr/textencoder"] = float(lrs[0])
-                idx = 1
-
-            for i in range(idx, len(lrs)):
-                lora_plus = ""
-                group_id = i
-
-                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
-                    lora_plus = '_lora+' if i % 2 == 1 else ''
-                    group_id = int((i / 2) + (i % 2 + 0.5))
-
-                logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i])
-                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
-                    logs[f"lr/d*lr/group{group_id}{lora_plus}"] = (
-                        lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
-                    )
-
-        else:
-            if args.network_train_text_encoder_only:
-                if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None:
-                    logs["lr/textencoder"] = float(lrs[0])
-                    logs["lr/textencoder_lora+"] = float(lrs[1])
-                else:
-                    logs["lr/textencoder"] = float(lrs[0])
-
-            elif args.network_train_unet_only:
-                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
-                    logs["lr/unet"] = float(lrs[0])
-                    logs["lr/unet_lora+"] = float(lrs[1])
-                else:
-                    logs["lr/unet"] = float(lrs[0])
+        for i, lr in enumerate(lrs):
+            if lr_descriptions is not None:
+                lr_desc = lr_descriptions[i]
             else:
-                if len(lrs) == 2:
-                    if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None:
-                        logs["lr/textencoder"] = float(lrs[0])
-                        logs["lr/textencoder_lora+"] = float(lrs[1])
-                    elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None:
-                        logs["lr/unet"] = float(lrs[0])
-                        logs["lr/unet_lora+"] = float(lrs[1])
-                    elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None:
-                        logs["lr/all"] = float(lrs[0])
-                        logs["lr/all_lora+"] = float(lrs[1])
-                    else:
-                        logs["lr/textencoder"] = float(lrs[0])
-                        logs["lr/unet"] = float(lrs[-1])
-                elif len(lrs) == 4:
-                    logs["lr/textencoder"] = float(lrs[0])
-                    logs["lr/textencoder_lora+"] = float(lrs[1])
-                    logs["lr/unet"] = float(lrs[2])
-                    logs["lr/unet_lora+"] = float(lrs[3])
+                idx = i - (0 if args.network_train_unet_only else -1)
+                if idx == -1:
+                    lr_desc = "textencoder"
                 else:
-                    logs["lr/all"] = float(lrs[0])
+                    if len(lrs) > 2:
+                        lr_desc = f"group{idx}"
+                    else:
+                        lr_desc = "unet"
+
+            logs[f"lr/{lr_desc}"] = lr

-            if (
-                args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
-            ):  # tracking d*lr value of unet.
-                logs["lr/d*lr"] = (
-                    lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
+            if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
+                # tracking d*lr value
+                logs[f"lr/d*lr/{lr_desc}"] = (
+                    lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                 )

         return logs
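
The effect on the logged keys, with made-up values: lr_descriptions lines up index-for-index with lr_scheduler.get_last_lr(), so a run that trains both text encoder and U-Net with a LoRA+ ratio would log roughly the following (a sketch, not captured output):

lrs = [1e-4, 1.6e-3, 1e-4, 1.6e-3]                       # hypothetical get_last_lr() values
lr_descriptions = ["textencoder", "textencoder plus", "unet", "unet plus"]
logs = {f"lr/{desc}": lr for desc, lr in zip(lr_descriptions, lrs)}
# {'lr/textencoder': 0.0001, 'lr/textencoder plus': 0.0016, 'lr/unet': 0.0001, 'lr/unet plus': 0.0016}
# With D-Adaptation or Prodigy optimizers, each group additionally gets an "lr/d*lr/<desc>" entry.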
@@ -358,6 +323,7 @@ def train(self, args):
         network.apply_to(text_encoder, unet, train_text_encoder, train_unet)

         if args.network_weights is not None:
+            # FIXME consider alpha of weights
             info = network.load_weights(args.network_weights)
             accelerator.print(f"load network weights from {args.network_weights}: {info}")

@@ -373,20 +339,23 @@ def train(self, args):

         # 後方互換性を確保するよ
         try:
-            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio)
+            results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
+            if type(results) is tuple:
+                trainable_params = results[0]
+                lr_descriptions = results[1]
+            else:
+                trainable_params = results
+                lr_descriptions = None
         except TypeError:
-            accelerator.print(
-                "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
-            )
+            # accelerator.print(
+            #     "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
+            # )
             trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
+            lr_descriptions = None
+        print(lr_descriptions)

         optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)

-        if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
-            assert (
-                (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name)
-            ), "LoRA+ and Prodigy/DAdaptation is not supported"
-
         # dataloaderを準備する
         # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
@@ -992,7 +961,9 @@ def remove_model(old_ckpt_name):
                     progress_bar.set_postfix(**{**max_mean_logs, **logs})

                 if args.logging_dir is not None:
-                    logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
+                    logs = self.generate_step_logs(
+                        args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
+                    )
                     accelerator.log(logs, step=global_step)

                 if global_step >= args.max_train_steps:
@@ -1143,6 +1114,9 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
+    # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
+    # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
     return parser

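
Taken together, train_network.py now accepts both interfaces: a network whose prepare_optimizer_params returns (params, lr_descriptions) gets named per-group learning-rate logs, while one that returns a plain list, or that only takes two arguments and lands in the TypeError fallback above, is simply logged without descriptions. A hedged sketch of the minimal surface a custom --network_module could expose to opt in; only the two method names come from this diff, everything else is illustrative:

# Illustrative third-party network, not from the repository.
class MyNetwork:
    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
        # wired up by the module's own create_network() from the --network_args values
        self.loraplus_lr_ratio = loraplus_lr_ratio
        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        params = [{"params": [], "lr": unet_lr if unet_lr is not None else default_lr}]
        lr_descriptions = ["unet"]
        return params, lr_descriptions   # returning just `params` would also still work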
