
Commit 4f203ce

Fused backward pass
1 parent 71e2c91 commit 4f203ce

3 files changed: +142 −6


library/adafactor_fused.py (+106)

@@ -0,0 +1,106 @@
+import math
+import torch
+from transformers import Adafactor
+
+@torch.no_grad()
+def adafactor_step_param(self, p, group):
+    if p.grad is None:
+        return
+    grad = p.grad
+    if grad.dtype in {torch.float16, torch.bfloat16}:
+        grad = grad.float()
+    if grad.is_sparse:
+        raise RuntimeError("Adafactor does not support sparse gradients.")
+
+    state = self.state[p]
+    grad_shape = grad.shape
+
+    factored, use_first_moment = Adafactor._get_options(group, grad_shape)
+    # State Initialization
+    if len(state) == 0:
+        state["step"] = 0
+
+        if use_first_moment:
+            # Exponential moving average of gradient values
+            state["exp_avg"] = torch.zeros_like(grad)
+        if factored:
+            state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
+            state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
+        else:
+            state["exp_avg_sq"] = torch.zeros_like(grad)
+
+        state["RMS"] = 0
+    else:
+        if use_first_moment:
+            state["exp_avg"] = state["exp_avg"].to(grad)
+        if factored:
+            state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
+            state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
+        else:
+            state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+
+    p_data_fp32 = p
+    if p.dtype in {torch.float16, torch.bfloat16}:
+        p_data_fp32 = p_data_fp32.float()
+
+    state["step"] += 1
+    state["RMS"] = Adafactor._rms(p_data_fp32)
+    lr = Adafactor._get_lr(group, state)
+
+    beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
+    update = (grad ** 2) + group["eps"][0]
+    if factored:
+        exp_avg_sq_row = state["exp_avg_sq_row"]
+        exp_avg_sq_col = state["exp_avg_sq_col"]
+
+        exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
+        exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
+
+        # Approximation of exponential moving average of square of gradient
+        update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+        update.mul_(grad)
+    else:
+        exp_avg_sq = state["exp_avg_sq"]
+
+        exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
+        update = exp_avg_sq.rsqrt().mul_(grad)
+
+    update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
+    update.mul_(lr)
+
+    if use_first_moment:
+        exp_avg = state["exp_avg"]
+        exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
+        update = exp_avg
+
+    if group["weight_decay"] != 0:
+        p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
+
+    p_data_fp32.add_(-update)
+
+    if p.dtype in {torch.float16, torch.bfloat16}:
+        p.copy_(p_data_fp32)
+
+
+@torch.no_grad()
+def adafactor_step(self, closure=None):
+    """
+    Performs a single optimization step
+
+    Arguments:
+        closure (callable, optional): A closure that reevaluates the model
+            and returns the loss.
+    """
+    loss = None
+    if closure is not None:
+        loss = closure()
+
+    for group in self.param_groups:
+        for p in group["params"]:
+            adafactor_step_param(self, p, group)
+
+    return loss
+
+def patch_adafactor_fused(optimizer: Adafactor):
+    optimizer.step_param = adafactor_step_param.__get__(optimizer)
+    optimizer.step = adafactor_step.__get__(optimizer)
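For context, a minimal usage sketch of the module above (not part of the commit): it assumes the file is importable as library.adafactor_fused and that transformers' Adafactor is installed; the model and hyperparameters are placeholders. patch_adafactor_fused binds step_param to the optimizer instance, so a caller can update one parameter at a time instead of stepping the whole model at once.

import torch
from transformers import Adafactor

from library.adafactor_fused import patch_adafactor_fused

model = torch.nn.Linear(16, 16)  # placeholder model
# lr is given explicitly because relative_step is disabled in this example
optimizer = Adafactor(model.parameters(), lr=1e-4, scale_parameter=False, relative_step=False)
patch_adafactor_fused(optimizer)  # binds step_param and swaps in the per-parameter step

loss = model(torch.randn(4, 16)).sum()
loss.backward()

# Update one parameter at a time; the patched optimizer.step() is just this
# loop, and the fused backward pass calls step_param from a gradient hook.
for group in optimizer.param_groups:
    for p in group["params"]:
        optimizer.step_param(p, group)
        p.grad = None  # free the gradient as soon as the tensor is updated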

library/train_util.py (+13)

@@ -2920,6 +2920,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
+    parser.add_argument(
+        "--fused_backward_pass",
+        action="store_true",
+        help="Combines backward pass and optimizer step to reduce VRAM usage / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。",
+    )


 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -3846,6 +3851,14 @@ def get_optimizer(args, trainable_params):
         optimizer_type = "AdamW"
     optimizer_type = optimizer_type.lower()

+    if args.fused_backward_pass:
+        assert (
+            optimizer_type == "Adafactor".lower()
+        ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
+        assert (
+            args.gradient_accumulation_steps == 1
+        ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"
+
     # 引数を分解する
     optimizer_kwargs = {}
     if args.optimizer_args is not None and len(args.optimizer_args) > 0:
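A simplified sketch of how the new flag and its guards interact (standalone illustration, not sd-scripts' actual plumbing): the option is opt-in and is rejected unless the optimizer is Adafactor and gradient accumulation is left at 1.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--optimizer_type", type=str, default="")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--fused_backward_pass", action="store_true")
args = parser.parse_args(["--optimizer_type", "Adafactor", "--fused_backward_pass"])

optimizer_type = args.optimizer_type.lower()
if args.fused_backward_pass:
    # mirrors the asserts above: Adafactor only, no gradient accumulation
    assert optimizer_type == "adafactor", "fused_backward_pass currently only works with Adafactor"
    assert args.gradient_accumulation_steps == 1, "fused_backward_pass does not work with gradient accumulation"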

sdxl_train.py (+23 −6)

@@ -430,6 +430,20 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder2 = accelerator.prepare(text_encoder2)
         optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

+    if args.fused_backward_pass:
+        import library.adafactor_fused
+        library.adafactor_fused.patch_adafactor_fused(optimizer)
+        for param_group in optimizer.param_groups:
+            for parameter in param_group["params"]:
+                if parameter.requires_grad:
+                    def __grad_hook(tensor: torch.Tensor, param_group=param_group):
+                        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                        optimizer.step_param(tensor, param_group)
+                        tensor.grad = None
+
+                    parameter.register_post_accumulate_grad_hook(__grad_hook)
+
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
@@ -619,13 +633,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)

                 accelerator.backward(loss)
-                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                    params_to_clip = []
-                    for m in training_models:
-                        params_to_clip.extend(m.parameters())
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

-                optimizer.step()
+                if not args.fused_backward_pass:
+                    if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                        params_to_clip = []
+                        for m in training_models:
+                            params_to_clip.extend(m.parameters())
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                    optimizer.step()
+
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
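The VRAM saving comes from the hook mechanism used above: torch.Tensor.register_post_accumulate_grad_hook (added in PyTorch 2.1) fires for each parameter as soon as its gradient has finished accumulating during backward, so the parameter can be updated and its gradient freed before the rest of the backward pass completes, and the full set of gradients is never held at once. A minimal standalone sketch of that pattern, using a plain SGD-style update instead of the trainer's Adafactor path; the model and learning rate are placeholders:

import torch

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
lr = 1e-3

def make_hook(lr: float):
    @torch.no_grad()
    def hook(param: torch.Tensor):
        # called during backward, right after param.grad is fully accumulated
        param.add_(param.grad, alpha=-lr)  # SGD-style update, for illustration only
        param.grad = None                  # free this gradient immediately
    return hook

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(make_hook(lr))

loss = model(torch.randn(8, 32)).sum()
loss.backward()  # parameters are updated inside backward; no separate optimizer.step()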
