@@ -345,8 +345,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     # calculate number of trainable parameters
     n_params = 0
-    for params in params_to_optimize:
-        for p in params["params"]:
+    for group in params_to_optimize:
+        for p in group["params"]:
             n_params += p.numel()
 
     accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
@@ -355,7 +355,44 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     # prepare classes needed for training
     accelerator.print("prepare optimizer, data loader etc.")
-    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+
+    if args.fused_optimizer_groups:
+        # calculate total number of parameters
+        n_total_params = sum(len(params["params"]) for params in params_to_optimize)
+        params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
+
+        # split params into groups
+        grouped_params = []
+        param_group = []
+        param_group_lr = -1
+        for group in params_to_optimize:
+            lr = group["lr"]
+            for p in group["params"]:
+                if lr != param_group_lr:
+                    if param_group:
+                        grouped_params.append({"params": param_group, "lr": param_group_lr})
+                        param_group = []
+                    param_group_lr = lr
+                param_group.append(p)
+                if len(param_group) == params_per_group:
+                    grouped_params.append({"params": param_group, "lr": param_group_lr})
+                    param_group = []
+                    param_group_lr = -1
+        if param_group:
+            grouped_params.append({"params": param_group, "lr": param_group_lr})
+
+        # prepare optimizers for each group
+        optimizers = []
+        for group in grouped_params:
+            _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
+            optimizers.append(optimizer)
+        optimizer = optimizers[0]  # avoid error in the following code
+
+        print(len(grouped_params))
+        logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
+
+    else:
+        _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
 
     # prepare dataloader
     # number of DataLoader processes: note that persistent_workers cannot be used with 0
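Aside on the grouping above: a minimal, self-contained sketch of the same splitting logic, with made-up parameter names and learning rates in place of real tensors (nothing here comes from train_util). Group boundaries are forced wherever the learning rate changes, and groups are otherwise capped at ceil(total / fused_optimizer_groups) parameters:

    import math

    # hypothetical stand-ins for params_to_optimize: names instead of tensors, made-up lrs
    params_to_optimize = [
        {"params": [f"unet_p{i}" for i in range(10)], "lr": 1e-4},
        {"params": [f"te1_p{i}" for i in range(4)], "lr": 5e-5},
    ]
    fused_optimizer_groups = 4  # what --fused_optimizer_groups would be set to

    n_total_params = sum(len(g["params"]) for g in params_to_optimize)
    params_per_group = math.ceil(n_total_params / fused_optimizer_groups)  # ceil(14 / 4) == 4

    grouped_params = []
    param_group, param_group_lr = [], -1
    for group in params_to_optimize:
        lr = group["lr"]
        for p in group["params"]:
            if lr != param_group_lr:  # never mix learning rates inside one fused group
                if param_group:
                    grouped_params.append({"params": param_group, "lr": param_group_lr})
                    param_group = []
                param_group_lr = lr
            param_group.append(p)
            if len(param_group) == params_per_group:  # cap the group size
                grouped_params.append({"params": param_group, "lr": param_group_lr})
                param_group, param_group_lr = [], -1
    if param_group:
        grouped_params.append({"params": param_group, "lr": param_group_lr})

    print([(len(g["params"]), g["lr"]) for g in grouped_params])
    # -> [(4, 0.0001), (4, 0.0001), (2, 0.0001), (4, 5e-05)]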
@@ -382,7 +419,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         train_dataset_group.set_max_train_steps(args.max_train_steps)
 
     # prepare lr scheduler
-    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+    if args.fused_optimizer_groups:
+        lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
+        lr_scheduler = lr_schedulers[0]  # avoid error in the following code
+    else:
+        lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
     # experimental feature: perform fp16/bf16 training including gradients (set the entire model to fp16/bf16)
     if args.full_fp16:
@@ -432,10 +473,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     if args.fused_backward_pass:
         import library.adafactor_fused
+
         library.adafactor_fused.patch_adafactor_fused(optimizer)
         for param_group in optimizer.param_groups:
             for parameter in param_group["params"]:
                 if parameter.requires_grad:
+
                     def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                         if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                             accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
@@ -444,6 +487,36 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
                     parameter.register_post_accumulate_grad_hook(__grad_hook)
 
+    elif args.fused_optimizer_groups:
+        for i in range(1, len(optimizers)):
+            optimizers[i] = accelerator.prepare(optimizers[i])
+            lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
+
+        global optimizer_hooked_count
+        global num_parameters_per_group
+        global parameter_optimizer_map
+        optimizer_hooked_count = {}
+        num_parameters_per_group = [0] * len(optimizers)
+        parameter_optimizer_map = {}
+        for opt_idx, optimizer in enumerate(optimizers):
+            for param_group in optimizer.param_groups:
+                for parameter in param_group["params"]:
+                    if parameter.requires_grad:
+
+                        def optimizer_hook(parameter: torch.Tensor):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
+
+                            i = parameter_optimizer_map[parameter]
+                            optimizer_hooked_count[i] += 1
+                            if optimizer_hooked_count[i] == num_parameters_per_group[i]:
+                                optimizers[i].step()
+                                optimizers[i].zero_grad()
+
+                        parameter.register_post_accumulate_grad_hook(optimizer_hook)
+                        parameter_optimizer_map[parameter] = opt_idx
+                        num_parameters_per_group[opt_idx] += 1
+
     # move to CPU when caching Text Encoder outputs
     if args.cache_text_encoder_outputs:
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
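For orientation, a minimal, self-contained sketch of the hook mechanism used in that hunk, on toy tensors rather than the script's models (register_post_accumulate_grad_hook needs PyTorch 2.1 or newer): each parameter's hook counts accumulated gradients and, once its group is complete, steps that group's optimizer and releases the group's gradients while backward is still running.

    import torch

    # toy stand-ins for the fused groups: two groups of parameters, one optimizer per group
    params = [torch.nn.Parameter(torch.randn(3)) for _ in range(4)]
    groups = [params[:2], params[2:]]
    optimizers = [torch.optim.SGD(g, lr=0.1) for g in groups]

    parameter_optimizer_map = {p: i for i, g in enumerate(groups) for p in g}
    num_parameters_per_group = [len(g) for g in groups]
    optimizer_hooked_count = {i: 0 for i in range(len(groups))}

    def optimizer_hook(p: torch.Tensor):
        i = parameter_optimizer_map[p]
        optimizer_hooked_count[i] += 1
        # once the last gradient of the group has been accumulated, step and release its grads
        if optimizer_hooked_count[i] == num_parameters_per_group[i]:
            optimizers[i].step()
            optimizers[i].zero_grad()

    for p in params:
        p.register_post_accumulate_grad_hook(optimizer_hook)  # PyTorch 2.1+

    optimizer_hooked_count = {i: 0 for i in range(len(groups))}  # reset per step, as the training loop does
    loss = sum((p * 2).sum() for p in params)
    loss.backward()  # both optimizers step inside backward, one group at a time
    print([p.grad for p in params])  # grads already released -> [None, None, None, None]

Because each group's gradients can be freed before the rest of the backward pass finishes, only part of the gradients has to be held at once, which is the memory saving this fused path aims for.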
@@ -518,6 +591,10 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
+
+            if args.fused_optimizer_groups:
+                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
+
             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
                     latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
@@ -596,7 +673,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
 
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
 
@@ -614,7 +693,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                     or args.masked_loss
                 ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                     if args.masked_loss:
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])
@@ -630,21 +711,28 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                    )
 
                 accelerator.backward(loss)
 
-                if not args.fused_backward_pass:
+                if not (args.fused_backward_pass or args.fused_optimizer_groups):
                     if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                         params_to_clip = []
                         for m in training_models:
                             params_to_clip.extend(m.parameters())
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                     optimizer.step()
+                elif args.fused_optimizer_groups:
+                    for i in range(1, len(optimizers)):
+                        lr_schedulers[i].step()
 
                 lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
+
+                if not (args.fused_backward_pass or args.fused_optimizer_groups):
+                    optimizer.zero_grad(set_to_none=True)
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
@@ -753,7 +841,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
     accelerator.end_training()
 
-    if args.save_state or args.save_state_on_train_end :
+    if args.save_state or args.save_state_on_train_end:
         train_util.save_state_on_train_end(args, accelerator)
 
     del accelerator  # delete this because the memory is needed after this point
@@ -822,6 +910,12 @@ def setup_parser() -> argparse.ArgumentParser:
         help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
         + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
     )
+    parser.add_argument(
+        "--fused_optimizer_groups",
+        type=int,
+        default=None,
+        help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
+    )
     return parser
 
 
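Since the new option defaults to None, the existing single-optimizer path stays the default; only an explicit value, e.g. --fused_optimizer_groups 10 (an arbitrary example), enables the grouped optimizers, per-group lr schedulers, and per-parameter hooks added in the earlier hunks, in which case the training loop skips its usual optimizer.step() / optimizer.zero_grad() calls.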