Commit 7fe8150

update loraplus on dylora/lora_fa
1 parent 52e64c6 commit 7fe8150

3 files changed: 71 additions, 34 deletions

networks/dylora.py

+29 -17

@@ -18,10 +18,13 @@
 import torch
 from torch import nn
 from library.utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)

+
 class DyLoRAModule(torch.nn.Module):
     """
     replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -195,7 +198,7 @@ def create_network(
         conv_alpha = 1.0
     else:
         conv_alpha = float(conv_alpha)
-
+
     if unit is not None:
         unit = int(unit)
     else:
@@ -211,6 +214,16 @@ def create_network(
         unit=unit,
         varbose=True,
     )
+
+    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
     return network


@@ -280,6 +293,10 @@ def __init__(
         self.alpha = alpha
         self.apply_to_conv = apply_to_conv

+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
         if modules_dim is not None:
             logger.info("create LoRA network from weights")
         else:
@@ -320,9 +337,9 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
                             lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
                             loras.append(lora)
             return loras
-
+
         text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
-
+
         self.text_encoder_loras = []
         for i, text_encoder in enumerate(text_encoders):
             if len(text_encoders) > 1:
@@ -331,7 +348,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
             else:
                 index = None
                 logger.info("create LoRA for Text Encoder")
-
+
             text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
             self.text_encoder_loras.extend(text_encoder_loras)

@@ -346,6 +363,11 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
         self.unet_loras = create_modules(True, unet, target_modules)
         logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

+    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+        self.loraplus_lr_ratio = loraplus_lr_ratio
+        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
@@ -407,15 +429,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
         """

     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(
-        self,
-        text_encoder_lr,
-        unet_lr,
-        default_lr,
-        text_encoder_loraplus_ratio=None,
-        unet_loraplus_ratio=None,
-        loraplus_ratio=None
-    ):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []

@@ -452,15 +466,13 @@ def assemble_params(loras, lr, ratio):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_loraplus_ratio or loraplus_ratio
+                self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio,
             )
             all_params.extend(params)

         if self.unet_loras:
             params = assemble_params(
-                self.unet_loras,
-                default_lr if unet_lr is None else unet_lr,
-                unet_loraplus_ratio or loraplus_ratio
+                self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_ratio
             )
             all_params.extend(params)

networks/lora.py

+6 -1

@@ -499,7 +499,8 @@ def create_network(
     loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
     loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
     loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
-    network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

     if block_lr_weight is not None:
         network.set_block_lr_weight(block_lr_weight)
@@ -855,6 +856,10 @@ def __init__(
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout

+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
         elif block_dims is not None:
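
The guard above means `set_loraplus_lr_ratio` is only called when at least one ratio was actually supplied via the network kwargs, while the `__init__` change guarantees the three attributes always exist, so later code can read them unconditionally. A toy sketch of that contract (`ToyNetwork` is a made-up stand-in, not the repository's `LoRANetwork`):

class ToyNetwork:
    def __init__(self):
        # Defaults mirror the added __init__ lines: attributes exist even if no ratio is set.
        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
        self.loraplus_text_encoder_lr_ratio = None

    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
        self.loraplus_lr_ratio = loraplus_lr_ratio
        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio

network = ToyNetwork()                          # no LoRA+ kwargs: setter is never called
assert network.loraplus_unet_lr_ratio is None   # but the attribute still exists

network.set_loraplus_lr_ratio(16.0, None, 4.0)  # at least one ratio was supplied
assert network.loraplus_text_encoder_lr_ratio == 4.0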

networks/lora_fa.py

+36 -16

@@ -15,8 +15,10 @@
 import torch
 import re
 from library.utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)

 RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -504,6 +506,15 @@ def create_network(
     if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
         network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)

+    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
     return network


@@ -529,7 +540,9 @@ def parse_floats(s):
             len(block_dims) == num_total_blocks
         ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
     else:
-        logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
+        logger.warning(
+            f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
+        )
         block_dims = [network_dim] * num_total_blocks

     if block_alphas is not None:
@@ -803,21 +816,31 @@ def __init__(
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout

+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
         elif block_dims is not None:
             logger.info(f"create LoRA network from block_dims")
-            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
             logger.info(f"block_dims: {block_dims}")
             logger.info(f"block_alphas: {block_alphas}")
             if conv_block_dims is not None:
                 logger.info(f"conv_block_dims: {conv_block_dims}")
                 logger.info(f"conv_block_alphas: {conv_block_alphas}")
         else:
             logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
-            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
             if self.conv_lora_dim is not None:
-                logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+                logger.info(
+                    f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
+                )

     # create module instances
     def create_modules(
@@ -939,6 +962,11 @@ def create_modules(
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)

+    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+        self.loraplus_lr_ratio = loraplus_lr_ratio
+        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
@@ -1033,15 +1061,7 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight

     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(
-        self,
-        text_encoder_lr,
-        unet_lr,
-        default_lr,
-        text_encoder_loraplus_ratio=None,
-        unet_loraplus_ratio=None,
-        loraplus_ratio=None
-    ):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []

@@ -1078,7 +1098,7 @@ def assemble_params(loras, lr, ratio):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_loraplus_ratio or loraplus_ratio
+                self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio,
             )
             all_params.extend(params)

@@ -1097,15 +1117,15 @@ def assemble_params(loras, lr, ratio):
                 params = assemble_params(
                     block_loras,
                     (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                    unet_loraplus_ratio or loraplus_ratio
+                    self.loraplus_unet_lr_ratio or self.loraplus_ratio,
                 )
                 all_params.extend(params)

         else:
             params = assemble_params(
                 self.unet_loras,
                 unet_lr if unet_lr is not None else default_lr,
-                unet_loraplus_ratio or loraplus_ratio
+                self.loraplus_unet_lr_ratio or self.loraplus_ratio,
             )
             all_params.extend(params)
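
With the ratios stored on the network, `prepare_optimizer_params` now takes only the three learning rates, and each `assemble_params` call picks the component-specific ratio with the general ratio as fallback. A small sketch of that precedence (the function name is hypothetical, not from the repository):

def effective_loraplus_ratio(component_ratio, general_ratio):
    # The per-component ratio (U-Net or Text Encoder) takes precedence;
    # otherwise the general LoRA+ ratio is used. As in the diff, `or` also
    # treats 0.0 as unset.
    return component_ratio or general_ratio

# e.g. with loraplus_lr_ratio=16 and loraplus_text_encoder_lr_ratio=4 passed as network args
assert effective_loraplus_ratio(4.0, 16.0) == 4.0     # Text Encoder groups use 4
assert effective_loraplus_ratio(None, 16.0) == 16.0   # U-Net groups fall back to 16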
