Skip to content

Commit 02298e3

Browse files
authored
Merge pull request kohya-ss#1331 from kohya-ss/lora-plus
Lora plus
2 parents 1ffc0b3 + 4419041 commit 02298e3

File tree

5 files changed

+387
-193
lines changed

5 files changed

+387
-193
lines changed

README.md

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,18 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
154154
- `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required.
155155
- Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side.
156156

157-
- Fixed some bugs when using DeepSpeed. Related [#1247]
157+
- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO!
158+
- LoRA+ is a method to improve training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiple. The original paper recommends 16, but adjust as needed. Please see the PR for details.
159+
- Specify `loraplus_lr_ratio` with `--network_args`. Example: `--network_args "loraplus_lr_ratio=16"`
160+
- `loraplus_unet_lr_ratio` and `loraplus_lr_ratio` can be specified separately for U-Net and Text Encoder.
161+
- Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc.
162+
- `network_module` `networks.lora` and `networks.dylora` are available.
163+
164+
- LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
165+
- Specify the learning rate and dim (rank) for each block.
166+
- See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only).
167+
168+
- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
158169

159170
- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。
160171
- optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。
@@ -171,7 +182,18 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
171182
- `--fused_optimizer_groups``--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。
172183
- 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。
173184

174-
- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247]
185+
- LoRA+ がサポートされました。PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) rockerBOO 氏に感謝します。
186+
- LoRA の UP 側(LoRA-B)の学習率を上げることで学習速度の向上を図る手法です。倍数で指定します。元の論文では 16 が推奨されていますが、データセット等にもよりますので、適宜調整してください。PR もあわせてご覧ください。
187+
- `--network_args``loraplus_lr_ratio` を指定します。例:`--network_args "loraplus_lr_ratio=16"`
188+
- `loraplus_unet_lr_ratio``loraplus_lr_ratio` で、U-Net および Text Encoder に個別の値を指定することも可能です。
189+
- 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など
190+
- `network_module``networks.lora` および `networks.dylora` で使用可能です。
191+
192+
- SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
193+
- ブロックごとに学習率および dim (rank) を指定することができます。
194+
- 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。
195+
196+
- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
175197

176198

177199
### Apr 7, 2024 / 2024-04-07: v0.8.7

docs/train_network_README-ja.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,16 +181,16 @@ python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.saf
181181

182182
詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。
183183

184-
SDXLは現在サポートしていません。
185-
186184
フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
187185

186+
SDXL では down/up 9 個、middle 3 個の値を指定してください。
187+
188188
`--network_args` で以下の引数を指定してください。
189189

190190
- `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。
191-
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します
191+
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個(SDXL では 9 個)の数値を指定します
192192
- プリセットからの指定 : `"down_lr_weight=sine"` のように指定します(サインカーブで重みを指定します)。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します(0.25~1.25になります)。
193-
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。
193+
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します(SDXL の場合は 3 個)
194194
- `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。
195195
- 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。
196196
- `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。
@@ -215,6 +215,9 @@ network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_l
215215

216216
フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。
217217

218+
SDXL では 23 個の値を指定してください。一部のブロックにはLoRA が存在しませんが、`sdxl_train.py`[階層別学習率](./train_SDXL-en.md) との互換性のためです。
219+
対応は、`0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out` です。
220+
218221
`--network_args` で以下の引数を指定してください。
219222

220223
- `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。

networks/dylora.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@
1818
import torch
1919
from torch import nn
2020
from library.utils import setup_logging
21+
2122
setup_logging()
2223
import logging
24+
2325
logger = logging.getLogger(__name__)
2426

27+
2528
class DyLoRAModule(torch.nn.Module):
2629
"""
2730
replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -195,7 +198,7 @@ def create_network(
195198
conv_alpha = 1.0
196199
else:
197200
conv_alpha = float(conv_alpha)
198-
201+
199202
if unit is not None:
200203
unit = int(unit)
201204
else:
@@ -211,6 +214,16 @@ def create_network(
211214
unit=unit,
212215
varbose=True,
213216
)
217+
218+
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
219+
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
220+
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
221+
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
222+
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
223+
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
224+
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
225+
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
226+
214227
return network
215228

216229

@@ -280,6 +293,10 @@ def __init__(
280293
self.alpha = alpha
281294
self.apply_to_conv = apply_to_conv
282295

296+
self.loraplus_lr_ratio = None
297+
self.loraplus_unet_lr_ratio = None
298+
self.loraplus_text_encoder_lr_ratio = None
299+
283300
if modules_dim is not None:
284301
logger.info("create LoRA network from weights")
285302
else:
@@ -320,9 +337,9 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
320337
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
321338
loras.append(lora)
322339
return loras
323-
340+
324341
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
325-
342+
326343
self.text_encoder_loras = []
327344
for i, text_encoder in enumerate(text_encoders):
328345
if len(text_encoders) > 1:
@@ -331,7 +348,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
331348
else:
332349
index = None
333350
logger.info("create LoRA for Text Encoder")
334-
351+
335352
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
336353
self.text_encoder_loras.extend(text_encoder_loras)
337354

@@ -346,6 +363,11 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
346363
self.unet_loras = create_modules(True, unet, target_modules)
347364
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
348365

366+
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
367+
self.loraplus_lr_ratio = loraplus_lr_ratio
368+
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
369+
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
370+
349371
def set_multiplier(self, multiplier):
350372
self.multiplier = multiplier
351373
for lora in self.text_encoder_loras + self.unet_loras:
@@ -406,27 +428,53 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
406428
logger.info(f"weights are merged")
407429
"""
408430

431+
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
409432
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
410433
self.requires_grad_(True)
411434
all_params = []
412435

413-
def enumerate_params(loras):
414-
params = []
436+
def assemble_params(loras, lr, ratio):
437+
param_groups = {"lora": {}, "plus": {}}
415438
for lora in loras:
416-
params.extend(lora.parameters())
439+
for name, param in lora.named_parameters():
440+
if ratio is not None and "lora_B" in name:
441+
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
442+
else:
443+
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
444+
445+
params = []
446+
for key in param_groups.keys():
447+
param_data = {"params": param_groups[key].values()}
448+
449+
if len(param_data["params"]) == 0:
450+
continue
451+
452+
if lr is not None:
453+
if key == "plus":
454+
param_data["lr"] = lr * ratio
455+
else:
456+
param_data["lr"] = lr
457+
458+
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
459+
continue
460+
461+
params.append(param_data)
462+
417463
return params
418464

419465
if self.text_encoder_loras:
420-
param_data = {"params": enumerate_params(self.text_encoder_loras)}
421-
if text_encoder_lr is not None:
422-
param_data["lr"] = text_encoder_lr
423-
all_params.append(param_data)
466+
params = assemble_params(
467+
self.text_encoder_loras,
468+
text_encoder_lr if text_encoder_lr is not None else default_lr,
469+
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
470+
)
471+
all_params.extend(params)
424472

425473
if self.unet_loras:
426-
param_data = {"params": enumerate_params(self.unet_loras)}
427-
if unet_lr is not None:
428-
param_data["lr"] = unet_lr
429-
all_params.append(param_data)
474+
params = assemble_params(
475+
self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio
476+
)
477+
all_params.extend(params)
430478

431479
return all_params
432480

0 commit comments

Comments
 (0)