Commit b56d5f7

add experimental option to fuse params to optimizer groups
1 parent 017b82e commit b56d5f7

1 file changed: sdxl_train.py (+104, -10 lines)

@@ -345,8 +345,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

     # calculate number of trainable parameters
     n_params = 0
-    for params in params_to_optimize:
-        for p in params["params"]:
+    for group in params_to_optimize:
+        for p in group["params"]:
             n_params += p.numel()

     accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
@@ -355,7 +355,44 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")
-    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+
+    if args.fused_optimizer_groups:
+        # calculate total number of parameters
+        n_total_params = sum(len(params["params"]) for params in params_to_optimize)
+        params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
+
+        # split params into groups
+        grouped_params = []
+        param_group = []
+        param_group_lr = -1
+        for group in params_to_optimize:
+            lr = group["lr"]
+            for p in group["params"]:
+                if lr != param_group_lr:
+                    if param_group:
+                        grouped_params.append({"params": param_group, "lr": param_group_lr})
+                        param_group = []
+                    param_group_lr = lr
+                param_group.append(p)
+                if len(param_group) == params_per_group:
+                    grouped_params.append({"params": param_group, "lr": param_group_lr})
+                    param_group = []
+                    param_group_lr = -1
+        if param_group:
+            grouped_params.append({"params": param_group, "lr": param_group_lr})
+
+        # prepare optimizers for each group
+        optimizers = []
+        for group in grouped_params:
+            _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
+            optimizers.append(optimizer)
+        optimizer = optimizers[0]  # avoid error in the following code
+
+        print(len(grouped_params))
+        logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
+
+    else:
+        _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

     # dataloaderを準備する
     # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
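For readers skimming the hunk above: the grouping pass walks the existing param groups in order, cutting a new chunk whenever the learning rate changes or the chunk reaches ceil(total_params / fused_optimizer_groups) parameters, so each resulting optimizer sees a single lr. Below is a minimal standalone sketch of that logic; the names (split_param_groups, chunk, chunk_lr) and the toy parameters are illustrative only, not part of sdxl_train.py.

import math
import torch

def split_param_groups(params_to_optimize, num_groups):
    """Flatten param groups into roughly equal-sized chunks, never mixing learning rates."""
    n_total = sum(len(g["params"]) for g in params_to_optimize)
    per_group = math.ceil(n_total / num_groups)

    grouped, chunk, chunk_lr = [], [], None
    for g in params_to_optimize:
        for p in g["params"]:
            if g["lr"] != chunk_lr:  # lr boundary: flush the current chunk
                if chunk:
                    grouped.append({"params": chunk, "lr": chunk_lr})
                    chunk = []
                chunk_lr = g["lr"]
            chunk.append(p)
            if len(chunk) == per_group:  # size boundary: flush as well
                grouped.append({"params": chunk, "lr": chunk_lr})
                chunk, chunk_lr = [], None
    if chunk:
        grouped.append({"params": chunk, "lr": chunk_lr})
    return grouped

# toy input: two logical groups with different learning rates
params_to_optimize = [
    {"params": [torch.nn.Parameter(torch.zeros(2)) for _ in range(5)], "lr": 1e-4},
    {"params": [torch.nn.Parameter(torch.zeros(2)) for _ in range(3)], "lr": 1e-5},
]
for g in split_param_groups(params_to_optimize, num_groups=4):
    print(len(g["params"]), g["lr"])  # can yield more than num_groups chunks at lr boundaries

Like the commit's version, the chunk count may exceed the requested number of groups when learning-rate boundaries force an early flush.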
@@ -382,7 +419,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     train_dataset_group.set_max_train_steps(args.max_train_steps)

     # lr schedulerを用意する
-    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+    if args.fused_optimizer_groups:
+        lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
+        lr_scheduler = lr_schedulers[0]  # avoid error in the following code
+    else:
+        lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

     # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
     if args.full_fp16:
@@ -432,10 +473,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

     if args.fused_backward_pass:
         import library.adafactor_fused
+
         library.adafactor_fused.patch_adafactor_fused(optimizer)
         for param_group in optimizer.param_groups:
             for parameter in param_group["params"]:
                 if parameter.requires_grad:
+
                     def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                         if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                             accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
@@ -444,6 +487,36 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):

                     parameter.register_post_accumulate_grad_hook(__grad_hook)

+    elif args.fused_optimizer_groups:
+        for i in range(1, len(optimizers)):
+            optimizers[i] = accelerator.prepare(optimizers[i])
+            lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
+
+        global optimizer_hooked_count
+        global num_parameters_per_group
+        global parameter_optimizer_map
+        optimizer_hooked_count = {}
+        num_parameters_per_group = [0] * len(optimizers)
+        parameter_optimizer_map = {}
+        for opt_idx, optimizer in enumerate(optimizers):
+            for param_group in optimizer.param_groups:
+                for parameter in param_group["params"]:
+                    if parameter.requires_grad:
+
+                        def optimizer_hook(parameter: torch.Tensor):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
+
+                            i = parameter_optimizer_map[parameter]
+                            optimizer_hooked_count[i] += 1
+                            if optimizer_hooked_count[i] == num_parameters_per_group[i]:
+                                optimizers[i].step()
+                                optimizers[i].zero_grad()
+
+                        parameter.register_post_accumulate_grad_hook(optimizer_hook)
+                        parameter_optimizer_map[parameter] = opt_idx
+                        num_parameters_per_group[opt_idx] += 1
+
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
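The registration loop above relies on Tensor.register_post_accumulate_grad_hook (available in PyTorch 2.1 and later): each trainable parameter gets a hook that fires once its gradient for the current backward pass is fully accumulated, and when every parameter owned by a given optimizer has reported in, that optimizer steps and zeroes immediately, so its gradients can be released before backward finishes. Here is a minimal self-contained sketch of the same pattern in plain PyTorch, without Accelerate, gradient clipping, or train_util; all names and the toy model are illustrative.

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 1))
params = [p for p in model.parameters() if p.requires_grad]

# two optimizers, each owning half of the parameters
groups = [params[: len(params) // 2], params[len(params) // 2 :]]
optimizers = [torch.optim.SGD(g, lr=1e-3) for g in groups]

param_to_opt = {}    # which optimizer owns each parameter
hooked_count = {}    # how many hooks fired per optimizer this step
params_per_opt = [len(g) for g in groups]

for opt_idx, group in enumerate(groups):
    for p in group:
        def hook(tensor: torch.Tensor):
            i = param_to_opt[tensor]
            hooked_count[i] += 1
            if hooked_count[i] == params_per_opt[i]:
                optimizers[i].step()       # all grads of this group are ready
                optimizers[i].zero_grad()  # free them right away
        p.register_post_accumulate_grad_hook(hook)
        param_to_opt[p] = opt_idx

# one training step: reset the counters, then backward drives the optimizers
hooked_count = {i: 0 for i in range(len(optimizers))}
loss = model(torch.randn(2, 4)).mean()
loss.backward()

As in the commit, the hook looks the parameter up in a map rather than capturing the loop index, so one hook body serves every parameter.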
@@ -518,6 +591,10 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):

         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
+
+            if args.fused_optimizer_groups:
+                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
+
             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
                     latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
@@ -596,7 +673,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):

                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )

                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

@@ -614,7 +693,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                     or args.masked_loss
                 ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                     if args.masked_loss:
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])
@@ -630,21 +711,28 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):

                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                    )

                 accelerator.backward(loss)

-                if not args.fused_backward_pass:
+                if not (args.fused_backward_pass or args.fused_optimizer_groups):
                     if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                         params_to_clip = []
                         for m in training_models:
                             params_to_clip.extend(m.parameters())
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                     optimizer.step()
+                elif args.fused_optimizer_groups:
+                    for i in range(1, len(optimizers)):
+                        lr_schedulers[i].step()

                 lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
+
+                if not (args.fused_backward_pass or args.fused_optimizer_groups):
+                    optimizer.zero_grad(set_to_none=True)

                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
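To make the control flow of the hunk above easier to follow: with --fused_optimizer_groups the optimizers have already been stepped and zeroed inside the backward pass by the parameter hooks, so the loop body only advances the schedulers; the scheduler prepared together with optimizer index 0 goes through the usual lr_scheduler.step(), and the remaining ones are stepped explicitly. A condensed, runnable sketch with stub objects (the flags and the Stub class are illustrative stand-ins, not project code):

# hypothetical stand-ins; no real model or Accelerate objects involved
fused_backward_pass = False
fused_optimizer_groups = 4  # as if --fused_optimizer_groups 4 were passed

class Stub:
    def step(self):
        pass
    def zero_grad(self, set_to_none=True):
        pass

optimizer = Stub()
lr_scheduler = Stub()  # scheduler for group 0 (or the single scheduler in the classic path)
lr_schedulers = [Stub() for _ in range(fused_optimizer_groups)]

# accelerator.backward(loss) would run here; in the fused paths the per-parameter
# hooks have already stepped and zeroed every optimizer by the time it returns

if not (fused_backward_pass or fused_optimizer_groups):
    optimizer.step()                        # classic path only
elif fused_optimizer_groups:
    for i in range(1, len(lr_schedulers)):
        lr_schedulers[i].step()             # extra schedulers are stepped by hand

lr_scheduler.step()

if not (fused_backward_pass or fused_optimizer_groups):
    optimizer.zero_grad(set_to_none=True)   # skipped in fused modes: hooks already zeroed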
@@ -753,7 +841,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):

     accelerator.end_training()

-    if args.save_state or args.save_state_on_train_end:
+    if args.save_state or args.save_state_on_train_end:
         train_util.save_state_on_train_end(args, accelerator)

     del accelerator  # この後メモリを使うのでこれは消す
@@ -822,6 +910,12 @@ def setup_parser() -> argparse.ArgumentParser:
         help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
         + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
     )
+    parser.add_argument(
+        "--fused_optimizer_groups",
+        type=int,
+        default=None,
+        help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
+    )
     return parser

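For completeness, a standalone argparse sketch showing how the new flag parses; this is not the project's setup_parser(), the value 10 is just an example, and leaving the flag off keeps the default of None, i.e. the feature disabled.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--fused_optimizer_groups",
    type=int,
    default=None,  # None keeps the experimental feature disabled
    help="number of optimizers for fused backward pass and optimizer step",
)

args = parser.parse_args(["--fused_optimizer_groups", "10"])
print(args.fused_optimizer_groups)  # -> 10; omitting the flag yields None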