@@ -357,27 +357,37 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     accelerator.print("prepare optimizer, data loader etc.")
 
     if args.fused_optimizer_groups:
+        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
+        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters.
+        # This balances memory usage and management complexity.
+
         # calculate total number of parameters
         n_total_params = sum(len(params["params"]) for params in params_to_optimize)
         params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
 
-        # split params into groups
+        # split params into groups, keeping the learning rate the same for all params in a group
+        # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders)
         grouped_params = []
         param_group = []
         param_group_lr = -1
         for group in params_to_optimize:
             lr = group["lr"]
             for p in group["params"]:
+                # if the learning rate is different for different params, start a new group
                 if lr != param_group_lr:
                     if param_group:
                         grouped_params.append({"params": param_group, "lr": param_group_lr})
                         param_group = []
                     param_group_lr = lr
+
                 param_group.append(p)
+
+                # if the group has enough parameters, start a new group
                 if len(param_group) == params_per_group:
                     grouped_params.append({"params": param_group, "lr": param_group_lr})
                     param_group = []
                     param_group_lr = -1
+
         if param_group:
             grouped_params.append({"params": param_group, "lr": param_group_lr})
 
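Side note on the grouping logic in the hunk above: it can be read as a standalone helper. The sketch below is only illustrative (the helper name `split_params_into_groups` and the toy `params_to_optimize` list are made up for the example); it mirrors the two rules in the diff, namely never mix learning rates inside a group and cap each group at `params_per_group` parameters.

```python
import math
import torch

def split_params_into_groups(params_to_optimize, n_groups):
    # hypothetical helper mirroring the diff above: split parameters into
    # roughly equal groups without mixing learning rates within a group
    n_total_params = sum(len(g["params"]) for g in params_to_optimize)
    params_per_group = math.ceil(n_total_params / n_groups)

    grouped_params, param_group, param_group_lr = [], [], -1
    for group in params_to_optimize:
        lr = group["lr"]
        for p in group["params"]:
            if lr != param_group_lr:  # learning rate changed -> close the current group
                if param_group:
                    grouped_params.append({"params": param_group, "lr": param_group_lr})
                    param_group = []
                param_group_lr = lr
            param_group.append(p)
            if len(param_group) == params_per_group:  # group is full -> close it
                grouped_params.append({"params": param_group, "lr": param_group_lr})
                param_group, param_group_lr = [], -1
    if param_group:
        grouped_params.append({"params": param_group, "lr": param_group_lr})
    return grouped_params

# toy usage: 6 U-Net-like params at lr=1e-4 and 2 text-encoder-like params at lr=1e-5
params_to_optimize = [
    {"params": [torch.nn.Parameter(torch.zeros(1)) for _ in range(6)], "lr": 1e-4},
    {"params": [torch.nn.Parameter(torch.zeros(1)) for _ in range(2)], "lr": 1e-5},
]
groups = split_params_into_groups(params_to_optimize, n_groups=4)
print([(len(g["params"]), g["lr"]) for g in groups])  # e.g. [(2, 1e-4), (2, 1e-4), (2, 1e-4), (2, 1e-5)]
```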
@@ -388,7 +398,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             optimizers.append(optimizer)
         optimizer = optimizers[0]  # avoid error in the following code
 
-        print(len(grouped_params))
         logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
 
     else:
@@ -420,6 +429,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     # prepare the lr scheduler
     if args.fused_optimizer_groups:
+        # prepare lr schedulers for each optimizer
         lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
         lr_scheduler = lr_schedulers[0]  # avoid error in the following code
     else:
@@ -472,6 +482,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
     if args.fused_backward_pass:
+        # use fused optimizer for backward pass: other optimizers will be supported in the future
        import library.adafactor_fused
 
        library.adafactor_fused.patch_adafactor_fused(optimizer)
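For context, the fused backward pass follows the linked PyTorch tutorial: the optimizer is stepped from a post-accumulate-grad hook so each gradient can be released right after it is produced, instead of being held until `backward()` finishes. A minimal sketch of that generic pattern with plain SGD (not the patched Adafactor used in this repo) might look like:

```python
import torch

model = torch.nn.Linear(16, 4)

# one optimizer per parameter, stepped as soon as that parameter's grad is ready
optimizer_dict = {p: torch.optim.SGD([p], lr=1e-2) for p in model.parameters()}

def optimizer_hook(parameter: torch.Tensor) -> None:
    # called after the gradient for `parameter` has been accumulated
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad(set_to_none=True)  # free the grad right away

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

# training step: no optimizer.step()/zero_grad() in the loop, the hooks handle it
loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
```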
@@ -488,16 +499,20 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                     parameter.register_post_accumulate_grad_hook(__grad_hook)
 
     elif args.fused_optimizer_groups:
+        # prepare for additional optimizers and lr schedulers
         for i in range(1, len(optimizers)):
             optimizers[i] = accelerator.prepare(optimizers[i])
             lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
 
+        # counters are used to determine when to step the optimizer
         global optimizer_hooked_count
         global num_parameters_per_group
         global parameter_optimizer_map
+
         optimizer_hooked_count = {}
         num_parameters_per_group = [0] * len(optimizers)
         parameter_optimizer_map = {}
+
         for opt_idx, optimizer in enumerate(optimizers):
             for param_group in optimizer.param_groups:
                 for parameter in param_group["params"]:
@@ -511,7 +526,7 @@ def optimizer_hook(parameter: torch.Tensor):
                             optimizer_hooked_count[i] += 1
                             if optimizer_hooked_count[i] == num_parameters_per_group[i]:
                                 optimizers[i].step()
-                                optimizers[i].zero_grad()
+                                optimizers[i].zero_grad(set_to_none=True)
 
                         parameter.register_post_accumulate_grad_hook(optimizer_hook)
                         parameter_optimizer_map[parameter] = opt_idx
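The `fused_optimizer_groups` path generalizes the same hook idea to one optimizer per parameter group: a counter per optimizer tracks how many of its parameters have received gradients, and `step()`/`zero_grad()` fire once the count reaches the group size. A rough standalone sketch of that counting mechanism (hypothetical two-group split, no Accelerate, no gradient clipping):

```python
import torch

model = torch.nn.Linear(16, 4)
params = list(model.parameters())

# hypothetical two-group split; in the script the groups come from the splitting logic above
groups = [params[:1], params[1:]]
optimizers = [torch.optim.SGD(g, lr=1e-2) for g in groups]

num_parameters_per_group = [len(g) for g in groups]
parameter_optimizer_map = {p: i for i, g in enumerate(groups) for p in g}
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}

def optimizer_hook(parameter: torch.Tensor):
    i = parameter_optimizer_map[parameter]
    optimizer_hooked_count[i] += 1
    # step once every parameter of this group has accumulated its gradient
    if optimizer_hooked_count[i] == num_parameters_per_group[i]:
        optimizers[i].step()
        optimizers[i].zero_grad(set_to_none=True)

for p in params:
    p.register_post_accumulate_grad_hook(optimizer_hook)

# one training step: reset the counters, then backward() triggers the hooks
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
```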
@@ -593,7 +608,7 @@ def optimizer_hook(parameter: torch.Tensor):
             current_step.value = global_step
 
             if args.fused_optimizer_groups:
-                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
+                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step
 
             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
@@ -725,14 +740,14 @@ def optimizer_hook(parameter: torch.Tensor):
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                     optimizer.step()
-                elif args.fused_optimizer_groups:
-                    for i in range(1, len(optimizers)):
-                        lr_schedulers[i].step()
-
-                lr_scheduler.step()
-
-                if not (args.fused_backward_pass or args.fused_optimizer_groups):
+                    lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
+                else:
+                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
+                    lr_scheduler.step()
+                    if args.fused_optimizer_groups:
+                        for i in range(1, len(optimizers)):
+                            lr_schedulers[i].step()
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients: