Skip to content

Commit 68467bd

Browse files
committed
Fix unset or invalid LR from making a param_group
1 parent 75833e8 commit 68467bd

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

networks/dylora.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,8 @@ def prepare_optimizer_params(
412412
text_encoder_lr,
413413
unet_lr,
414414
default_lr,
415-
unet_loraplus_ratio=None,
416415
text_encoder_loraplus_ratio=None,
416+
unet_loraplus_ratio=None,
417417
loraplus_ratio=None
418418
):
419419
self.requires_grad_(True)
@@ -441,7 +441,7 @@ def assemble_params(loras, lr, ratio):
441441
else:
442442
param_data["lr"] = lr
443443

444-
if ("lr" in param_data) and (param_data["lr"] == 0):
444+
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
445445
continue
446446

447447
params.append(param_data)

networks/lora.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1040,8 +1040,8 @@ def prepare_optimizer_params(
10401040
text_encoder_lr,
10411041
unet_lr,
10421042
default_lr,
1043-
unet_loraplus_ratio=None,
10441043
text_encoder_loraplus_ratio=None,
1044+
unet_loraplus_ratio=None,
10451045
loraplus_ratio=None
10461046
):
10471047
self.requires_grad_(True)
@@ -1069,7 +1069,8 @@ def assemble_params(loras, lr, ratio):
10691069
else:
10701070
param_data["lr"] = lr
10711071

1072-
if ("lr" in param_data) and (param_data["lr"] == 0):
1072+
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
1073+
print("NO LR skipping!")
10731074
continue
10741075

10751076
params.append(param_data)

networks/lora_fa.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1038,8 +1038,8 @@ def prepare_optimizer_params(
10381038
text_encoder_lr,
10391039
unet_lr,
10401040
default_lr,
1041-
unet_loraplus_ratio=None,
10421041
text_encoder_loraplus_ratio=None,
1042+
unet_loraplus_ratio=None,
10431043
loraplus_ratio=None
10441044
):
10451045
self.requires_grad_(True)
@@ -1067,7 +1067,7 @@ def assemble_params(loras, lr, ratio):
10671067
else:
10681068
param_data["lr"] = lr
10691069

1070-
if ("lr" in param_data) and (param_data["lr"] == 0):
1070+
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
10711071
continue
10721072

10731073
params.append(param_data)

0 commit comments

Comments
 (0)