
Commit 607e041

chore: Refactor optimizer group
1 parent b56d5f7 commit 607e041

File tree: 1 file changed (+26 -11 lines)

Diff for: sdxl_train.py

@@ -357,27 +357,37 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     accelerator.print("prepare optimizer, data loader etc.")

     if args.fused_optimizer_groups:
+        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
+        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters.
+        # This balances memory usage and management complexity.
+
         # calculate total number of parameters
         n_total_params = sum(len(params["params"]) for params in params_to_optimize)
         params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)

-        # split params into groups
+        # split params into groups, keeping the learning rate the same for all params in a group
+        # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders)
         grouped_params = []
         param_group = []
         param_group_lr = -1
         for group in params_to_optimize:
             lr = group["lr"]
             for p in group["params"]:
+                # if the learning rate is different for different params, start a new group
                 if lr != param_group_lr:
                     if param_group:
                         grouped_params.append({"params": param_group, "lr": param_group_lr})
                         param_group = []
                     param_group_lr = lr
+
                 param_group.append(p)
+
+                # if the group has enough parameters, start a new group
                 if len(param_group) == params_per_group:
                     grouped_params.append({"params": param_group, "lr": param_group_lr})
                     param_group = []
                     param_group_lr = -1
+
         if param_group:
             grouped_params.append({"params": param_group, "lr": param_group_lr})
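The added comments state the grouping rule, but it is easy to lose inside the diff, so here is the same algorithm restated as a small self-contained sketch (the helper name split_params_into_groups and the example figures are illustrative, not part of the commit): parameters are packed into chunks of at most ceil(total / fused_optimizer_groups) entries, and a chunk is also closed whenever the learning rate changes, so every chunk carries exactly one lr.

import math

def split_params_into_groups(params_to_optimize, n_groups):
    # params_to_optimize: list of {"params": [...], "lr": float}, as in the code above
    n_total_params = sum(len(g["params"]) for g in params_to_optimize)
    params_per_group = math.ceil(n_total_params / n_groups)

    grouped_params, param_group, param_group_lr = [], [], -1
    for group in params_to_optimize:
        lr = group["lr"]
        for p in group["params"]:
            if lr != param_group_lr:
                # a change in learning rate closes the current chunk
                if param_group:
                    grouped_params.append({"params": param_group, "lr": param_group_lr})
                    param_group = []
                param_group_lr = lr
            param_group.append(p)
            if len(param_group) == params_per_group:
                # a full chunk is closed as well
                grouped_params.append({"params": param_group, "lr": param_group_lr})
                param_group, param_group_lr = [], -1
    if param_group:
        grouped_params.append({"params": param_group, "lr": param_group_lr})
    return grouped_params

# e.g. 5 U-Net tensors at lr=1e-4 plus 3 text-encoder tensors at lr=1e-5, split with n_groups=2,
# gives ceil(8 / 2) = 4 params per chunk and yields 3 groups (4 + 1 at 1e-4, then 3 at 1e-5),
# because a group never mixes learning rates.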

@@ -388,7 +398,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             optimizers.append(optimizer)
         optimizer = optimizers[0]  # avoid error in the following code

-        print(len(grouped_params))
         logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")

     else:
@@ -420,6 +429,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

     # prepare the lr scheduler
     if args.fused_optimizer_groups:
+        # prepare lr schedulers for each optimizer
         lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
         lr_scheduler = lr_schedulers[0]  # avoid error in the following code
     else:
@@ -472,6 +482,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

     if args.fused_backward_pass:
+        # use fused optimizer for backward pass: other optimizers will be supported in the future
         import library.adafactor_fused

         library.adafactor_fused.patch_adafactor_fused(optimizer)
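For context, the linked PyTorch tutorial fuses the optimizer step into the backward pass by giving every parameter its own optimizer and stepping it from a post-accumulate-grad hook, so each gradient can be released as soon as it has been applied instead of being held until backward finishes. A minimal sketch of that pattern, with plain SGD as a stand-in (patch_adafactor_fused itself is specific to this repository and is not reproduced here):

import torch

model = torch.nn.Linear(16, 16)

# one optimizer per parameter, as in the "optimizer step in backward" tutorial
optimizer_dict = {p: torch.optim.SGD([p], lr=1e-2) for p in model.parameters()}

def optimizer_hook(param: torch.Tensor) -> None:
    # runs once the gradient for this parameter has been fully accumulated
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad(set_to_none=True)  # free the grad right away

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

# the training loop then only calls backward(); no explicit optimizer.step() remains
loss = model(torch.randn(4, 16)).sum()
loss.backward()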
@@ -488,16 +499,20 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                     parameter.register_post_accumulate_grad_hook(__grad_hook)

     elif args.fused_optimizer_groups:
+        # prepare for additional optimizers and lr schedulers
         for i in range(1, len(optimizers)):
             optimizers[i] = accelerator.prepare(optimizers[i])
             lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])

+        # counters are used to determine when to step the optimizer
         global optimizer_hooked_count
         global num_parameters_per_group
         global parameter_optimizer_map
+
         optimizer_hooked_count = {}
         num_parameters_per_group = [0] * len(optimizers)
         parameter_optimizer_map = {}
+
         for opt_idx, optimizer in enumerate(optimizers):
             for param_group in optimizer.param_groups:
                 for parameter in param_group["params"]:
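Unlike the tutorial's one-optimizer-per-parameter setup, a group optimizer here owns many parameters, so a hook fired for a single parameter must not step the optimizer immediately. The bookkeeping above therefore counts how many of an optimizer's parameters have received their gradient and steps that optimizer only when the count reaches the group size. A self-contained sketch of the scheme (the toy model, group sizes, and SGD optimizers are illustrative, not taken from sdxl_train.py):

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
params = list(model.parameters())

# two "fused optimizer groups", each owning half of the parameters
optimizers = [torch.optim.SGD(params[:2], lr=1e-2), torch.optim.SGD(params[2:], lr=1e-2)]

hooked_count = {i: 0 for i in range(len(optimizers))}
params_per_opt = [0] * len(optimizers)
param_to_opt = {}

def optimizer_hook(parameter: torch.Tensor) -> None:
    i = param_to_opt[parameter]
    hooked_count[i] += 1
    # step this group's optimizer only after all of its parameters have gradients
    if hooked_count[i] == params_per_opt[i]:
        optimizers[i].step()
        optimizers[i].zero_grad(set_to_none=True)

for opt_idx, opt in enumerate(optimizers):
    for group in opt.param_groups:
        for p in group["params"]:
            param_to_opt[p] = opt_idx
            params_per_opt[opt_idx] += 1
            p.register_post_accumulate_grad_hook(optimizer_hook)

hooked_count = {i: 0 for i in range(len(optimizers))}  # reset at the start of every training step
loss = model(torch.randn(4, 8)).sum()
loss.backward()  # each group's optimizer steps once its last gradient arrives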
@@ -511,7 +526,7 @@ def optimizer_hook(parameter: torch.Tensor):
                             optimizer_hooked_count[i] += 1
                             if optimizer_hooked_count[i] == num_parameters_per_group[i]:
                                 optimizers[i].step()
-                                optimizers[i].zero_grad()
+                                optimizers[i].zero_grad(set_to_none=True)

                         parameter.register_post_accumulate_grad_hook(optimizer_hook)
                         parameter_optimizer_map[parameter] = opt_idx
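The only behavioral change in this hunk is zero_grad(set_to_none=True): the gradient tensors are dropped instead of being filled with zeros, so their memory is released as soon as the group has been stepped, which fits the memory-saving goal of the fused paths. A tiny stand-alone illustration (not from the commit):

import torch

p = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.SGD([p], lr=0.1)

p.sum().backward()  # p.grad is now a tensor of ones
opt.step()
opt.zero_grad(set_to_none=True)

print(p.grad)  # None: the gradient buffer was freed rather than zero-filled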
@@ -593,7 +608,7 @@ def optimizer_hook(parameter: torch.Tensor):
             current_step.value = global_step

             if args.fused_optimizer_groups:
-                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
+                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step

             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
@@ -725,14 +740,14 @@ def optimizer_hook(parameter: torch.Tensor):
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                     optimizer.step()
-                elif args.fused_optimizer_groups:
-                    for i in range(1, len(optimizers)):
-                        lr_schedulers[i].step()
-
-                lr_scheduler.step()
-
-                if not (args.fused_backward_pass or args.fused_optimizer_groups):
+                    lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
+                else:
+                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
+                    lr_scheduler.step()
+                    if args.fused_optimizer_groups:
+                        for i in range(1, len(optimizers)):
+                            lr_schedulers[i].step()

                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
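The trailing check relies on Accelerate's gradient accumulation: accelerator.sync_gradients is True only on the iterations where gradients are actually synchronized and an optimization step takes effect. A minimal usage sketch (the model, data, and hyperparameters are placeholders, not taken from sdxl_train.py):

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
dataloader = torch.utils.data.DataLoader(torch.randn(32, 8), batch_size=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

global_step = 0
for batch in dataloader:
    with accelerator.accumulate(model):
        loss = model(batch).mean()
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    # True once every gradient_accumulation_steps micro-batches (and on the final batch),
    # i.e. when an optimization step has actually been performed
    if accelerator.sync_gradients:
        global_step += 1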

0 commit comments