Commit 75833e8

Fix default LR, Add overall LoRA+ ratio, Add log
`--loraplus_lr_ratio` added as an overall ratio for both TE and UNet. Logging added for the LoRA+ learning rates.
1 parent 1933ab4 commit 75833e8

File tree: 5 files changed, +101 −60 lines
library/train_util.py (+1)

@@ -2789,6 +2789,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
+    parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
     parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
     parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")

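As a quick illustration of how the three flags interact, here is a minimal standalone sketch (not the repository's parser; the values are made up). Downstream, the per-module ratios are expected to take precedence over the new overall `--loraplus_lr_ratio`:

# Minimal standalone sketch showing how the three flags added above parse;
# flag names match the diff, the example values are illustrative only.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")

args = parser.parse_args(["--loraplus_lr_ratio", "16", "--loraplus_text_encoder_lr_ratio", "4"])
print(args.loraplus_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_text_encoder_lr_ratio)
# 16.0 None 4.0 -- the per-module ratios override the overall ratio downstream.
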
networks/dylora.py (+12 −12)

@@ -412,32 +412,32 @@ def prepare_optimizer_params(
         text_encoder_lr,
         unet_lr,
         default_lr,
-        unet_lora_plus_ratio=None,
-        text_encoder_lora_plus_ratio=None
+        unet_loraplus_ratio=None,
+        text_encoder_loraplus_ratio=None,
+        loraplus_ratio=None
     ):
         self.requires_grad_(True)
         all_params = []

-        def assemble_params(loras, lr, lora_plus_ratio):
+        def assemble_params(loras, lr, ratio):
             param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
                 for name, param in lora.named_parameters():
-                    if lora_plus_ratio is not None and "lora_up" in name:
+                    if ratio is not None and "lora_B" in name:
                         param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                     else:
                         param_groups["lora"][f"{lora.lora_name}.{name}"] = param

-            # assigned_param_groups = ""
-            # for group in param_groups:
-            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
-            # logger.info(assigned_param_groups)
-
             params = []
             for key in param_groups.keys():
                 param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
                 if lr is not None:
                     if key == "plus":
-                        param_data["lr"] = lr * lora_plus_ratio
+                        param_data["lr"] = lr * ratio
                     else:
                         param_data["lr"] = lr

@@ -452,15 +452,15 @@ def assemble_params(loras, lr, lora_plus_ratio):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_lora_plus_ratio
+                text_encoder_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)

         if self.unet_loras:
             params = assemble_params(
                 self.unet_loras,
                 default_lr if unet_lr is None else unet_lr,
-                unet_lora_plus_ratio
+                unet_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)

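For readers unfamiliar with LoRA+, the grouping above is the whole trick: parameters whose names contain `lora_B` (DyLoRA) or `lora_up` (the other network modules) are optimized at `lr * ratio`, everything else at the base `lr`. A minimal standalone sketch with toy tensors (names and shapes are invented for illustration, not taken from the repository):

# Standalone illustration of the grouping rule above, with toy parameters.
import torch

toy_params = {
    "lora_unet_mid_block.lora_A": torch.nn.Parameter(torch.zeros(4, 320)),
    "lora_unet_mid_block.lora_B": torch.nn.Parameter(torch.zeros(320, 4)),
}

base_lr, ratio = 1e-4, 16.0
plus = [p for n, p in toy_params.items() if "lora_B" in n]      # "plus" group
lora = [p for n, p in toy_params.items() if "lora_B" not in n]  # "lora" group

param_groups = [
    {"params": lora, "lr": base_lr},          # base learning rate
    {"params": plus, "lr": base_lr * ratio},  # boosted learning rate for the up/B matrices
]
optimizer = torch.optim.AdamW(param_groups)
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 0.0016]
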
networks/lora.py (+14 −14)

@@ -1040,32 +1040,32 @@ def prepare_optimizer_params(
         text_encoder_lr,
         unet_lr,
         default_lr,
-        unet_lora_plus_ratio=None,
-        text_encoder_lora_plus_ratio=None
+        unet_loraplus_ratio=None,
+        text_encoder_loraplus_ratio=None,
+        loraplus_ratio=None
     ):
         self.requires_grad_(True)
         all_params = []

-        def assemble_params(loras, lr, lora_plus_ratio):
+        def assemble_params(loras, lr, ratio):
             param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
                 for name, param in lora.named_parameters():
-                    if lora_plus_ratio is not None and "lora_up" in name:
+                    if ratio is not None and "lora_up" in name:
                         param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                     else:
                         param_groups["lora"][f"{lora.lora_name}.{name}"] = param

-            # assigned_param_groups = ""
-            # for group in param_groups:
-            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
-            # logger.info(assigned_param_groups)
-
             params = []
             for key in param_groups.keys():
                 param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
                 if lr is not None:
                     if key == "plus":
-                        param_data["lr"] = lr * lora_plus_ratio
+                        param_data["lr"] = lr * ratio
                     else:
                         param_data["lr"] = lr

@@ -1080,7 +1080,7 @@ def assemble_params(loras, lr, lora_plus_ratio):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_lora_plus_ratio
+                text_encoder_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)

@@ -1099,15 +1099,15 @@ def assemble_params(loras, lr, lora_plus_ratio):
                     params = assemble_params(
                         block_loras,
                         (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                        unet_lora_plus_ratio
+                        unet_loraplus_ratio or loraplus_ratio
                     )
                     all_params.extend(params)

            else:
                 params = assemble_params(
                     self.unet_loras,
-                    default_lr if unet_lr is None else unet_lr,
-                    unet_lora_plus_ratio
+                    unet_lr if unet_lr is not None else default_lr,
+                    unet_loraplus_ratio or loraplus_ratio
                 )
                 all_params.extend(params)

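A small caveat on the `text_encoder_loraplus_ratio or loraplus_ratio` fallback used in these hunks: Python's `or` treats `0.0` as falsy, so an explicit per-module ratio of `0` would also fall back to the overall ratio. A tiny sketch of the precedence, assuming ratios are either `None` or positive floats (helper name is hypothetical):

def resolve_ratio(module_ratio, overall_ratio):
    # Mirrors the `module_ratio or overall_ratio` pattern in the diffs:
    # the per-module ratio wins when set, otherwise the overall ratio applies.
    return module_ratio or overall_ratio

print(resolve_ratio(None, 16.0))  # 16.0 -> falls back to --loraplus_lr_ratio
print(resolve_ratio(4.0, 16.0))   # 4.0  -> per-module override wins
print(resolve_ratio(0.0, 16.0))   # 16.0 -> 0.0 is falsy, so it also falls back
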
networks/lora_fa.py (+15 −15)

@@ -1038,32 +1038,32 @@ def prepare_optimizer_params(
         text_encoder_lr,
         unet_lr,
         default_lr,
-        unet_lora_plus_ratio=None,
-        text_encoder_lora_plus_ratio=None
+        unet_loraplus_ratio=None,
+        text_encoder_loraplus_ratio=None,
+        loraplus_ratio=None
     ):
         self.requires_grad_(True)
         all_params = []

-        def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
+        def assemble_params(loras, lr, ratio):
             param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                for name, param in lora.get_trainable_named_params():
-                    if lora_plus_ratio is not None and "lora_up" in name:
+                for name, param in lora.named_parameters():
+                    if ratio is not None and "lora_up" in name:
                         param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                     else:
                         param_groups["lora"][f"{lora.lora_name}.{name}"] = param

-            # assigned_param_groups = ""
-            # for group in param_groups:
-            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
-            # logger.info(assigned_param_groups)
-
             params = []
             for key in param_groups.keys():
                 param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
                 if lr is not None:
                     if key == "plus":
-                        param_data["lr"] = lr * lora_plus_ratio
+                        param_data["lr"] = lr * ratio
                     else:
                         param_data["lr"] = lr

@@ -1078,7 +1078,7 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_lora_plus_ratio
+                text_encoder_loraplus_ratio or loraplus_ratio
             )
             all_params.extend(params)

@@ -1097,15 +1097,15 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
                     params = assemble_params(
                         block_loras,
                         (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                        unet_lora_plus_ratio
+                        unet_loraplus_ratio or loraplus_ratio
                     )
                     all_params.extend(params)

            else:
                 params = assemble_params(
                     self.unet_loras,
-                    default_lr if unet_lr is None else unet_lr,
-                    unet_lora_plus_ratio
+                    unet_lr if unet_lr is not None else default_lr,
+                    unet_loraplus_ratio or loraplus_ratio
                 )
                 all_params.extend(params)

train_network.py (+59 −19)

@@ -66,34 +66,69 @@ def generate_step_logs(

         lrs = lr_scheduler.get_last_lr()

-        if args.network_train_text_encoder_only or len(lrs) <= 2:  # not block lr (or single block)
-            if args.network_train_unet_only:
-                logs["lr/unet"] = float(lrs[0])
-            elif args.network_train_text_encoder_only:
-                logs["lr/textencoder"] = float(lrs[0])
-            else:
-                logs["lr/textencoder"] = float(lrs[0])
-                logs["lr/unet"] = float(lrs[-1])  # may be same to textencoder
-
-            if (
-                args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
-            ):  # tracking d*lr value of unet.
-                logs["lr/d*lr"] = (
-                    lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
-                )
-        else:
+        if len(lrs) > 4:
             idx = 0
             if not args.network_train_unet_only:
                 logs["lr/textencoder"] = float(lrs[0])
                 idx = 1

             for i in range(idx, len(lrs)):
-                logs[f"lr/group{i}"] = float(lrs[i])
+                lora_plus = ""
+                group_id = i
+
+                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+                    lora_plus = '_lora+' if i % 2 == 1 else ''
+                    group_id = int((i / 2) + (i % 2 + 0.5))
+
+                logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i])
                 if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
-                    logs[f"lr/d*lr/group{i}"] = (
+                    logs[f"lr/d*lr/group{group_id}{lora_plus}"] = (
                         lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                     )

+        else:
+            if args.network_train_text_encoder_only:
+                if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None:
+                    logs["lr/textencoder"] = float(lrs[0])
+                    logs["lr/textencoder_lora+"] = float(lrs[1])
+                else:
+                    logs["lr/textencoder"] = float(lrs[0])
+
+            elif args.network_train_unet_only:
+                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+                    logs["lr/unet"] = float(lrs[0])
+                    logs["lr/unet_lora+"] = float(lrs[1])
+                else:
+                    logs["lr/unet"] = float(lrs[0])
+            else:
+                if len(lrs) == 2:
+                    if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None:
+                        logs["lr/textencoder"] = float(lrs[0])
+                        logs["lr/textencoder_lora+"] = float(lrs[1])
+                    elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None:
+                        logs["lr/unet"] = float(lrs[0])
+                        logs["lr/unet_lora+"] = float(lrs[1])
+                    elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None:
+                        logs["lr/all"] = float(lrs[0])
+                        logs["lr/all_lora+"] = float(lrs[1])
+                    else:
+                        logs["lr/textencoder"] = float(lrs[0])
+                        logs["lr/unet"] = float(lrs[-1])
+                elif len(lrs) == 4:
+                    logs["lr/textencoder"] = float(lrs[0])
+                    logs["lr/textencoder_lora+"] = float(lrs[1])
+                    logs["lr/unet"] = float(lrs[2])
+                    logs["lr/unet_lora+"] = float(lrs[3])
+                else:
+                    logs["lr/all"] = float(lrs[0])
+
+            if (
+                args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
+            ):  # tracking d*lr value of unet.
+                logs["lr/d*lr"] = (
+                    lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
+                )
+
         return logs

     def assert_extra_args(self, args, train_dataset_group):

@@ -339,7 +374,7 @@ def train(self, args):

         # 後方互換性を確保するよ
         try:
-            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio)
+            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio)
         except TypeError:
             accelerator.print(
                 "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"

@@ -348,6 +383,11 @@ def train(self, args):

         optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)

+        if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
+            assert (
+                (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name)
+            ), "LoRA+ and Prodigy/DAdaptation is not supported"
+
         # dataloaderを準備する
         # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers

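To make the logging branches above concrete, here is a standalone sketch of the `len(lrs) == 4` case (both text encoder and U-Net trained with a LoRA+ ratio). The learning-rate values are placeholders, not real scheduler output:

# Sketch of the len(lrs) == 4 branch: base/plus pairs of param groups for the
# text encoder and the U-Net, logged under the new *_lora+ keys.
lrs = [1e-5, 1e-5 * 16, 1e-4, 1e-4 * 16]

logs = {
    "lr/textencoder": float(lrs[0]),
    "lr/textencoder_lora+": float(lrs[1]),
    "lr/unet": float(lrs[2]),
    "lr/unet_lora+": float(lrs[3]),
}
print(logs)
# {'lr/textencoder': 1e-05, 'lr/textencoder_lora+': 0.00016, 'lr/unet': 0.0001, 'lr/unet_lora+': 0.0016}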