@@ -493,17 +493,24 @@ def train(self, args):
         # before resuming make hook for saving/loading to save/load the network weights only
         def save_model_hook(models, weights, output_dir):
             # pop weights of models other than the network, to save only the network weights
-            # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
-            if accelerator.is_main_process or args.deepspeed:
+            if accelerator.is_main_process:
                 remove_indices = []
                 for i, model in enumerate(models):
                     if not isinstance(model, type(accelerator.unwrap_model(network))):
                         remove_indices.append(i)
                 for i in reversed(remove_indices):
-                    if len(weights) > i:
-                        weights.pop(i)
+                    weights.pop(i)
                 # print(f"save model hook: {len(weights)} weights will be saved")

+            # save current epoch and step
+            train_state_file = os.path.join(output_dir, "train_state.json")
+            # +1 is needed because the state is saved before current_step is set from global_step
+            logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value + 1}")
+            with open(train_state_file, "w", encoding="utf-8") as f:
+                json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
+
+        steps_from_state = None
+
         def load_model_hook(models, input_dir):
             # remove models except network
             remove_indices = []
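
The save hook above persists the resume position in a small JSON file next to Accelerate's own state files. A minimal sketch of reading it back, assuming a hypothetical state directory name (illustrative, not from the patch):

import json, os

state_dir = "output/last-state"  # hypothetical directory passed to save_state
with open(os.path.join(state_dir, "train_state.json"), encoding="utf-8") as f:
    print(json.load(f))  # e.g. {"current_epoch": 2, "current_step": 1001}
# current_step is global_step + 1: the hook fires before current_step is
# updated for the step that triggered the save.
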
@@ -514,6 +521,15 @@ def load_model_hook(models, input_dir):
                 models.pop(i)
             # print(f"load model hook: {len(models)} models will be loaded")

+            # load current epoch and step from the saved train state
+            nonlocal steps_from_state
+            train_state_file = os.path.join(input_dir, "train_state.json")
+            if os.path.exists(train_state_file):
+                with open(train_state_file, "r", encoding="utf-8") as f:
+                    data = json.load(f)
+                steps_from_state = data["current_step"]
+                logger.info(f"load train state from {train_state_file}: {data}")
+
         accelerator.register_save_state_pre_hook(save_model_hook)
         accelerator.register_load_state_pre_hook(load_model_hook)

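
For context, Accelerate calls these pre-hooks from save_state/load_state, so train_state.json travels with every checkpoint automatically. A minimal sketch of the round trip, assuming the hooks defined above and an illustrative directory name:

from accelerate import Accelerator

accelerator = Accelerator()
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
accelerator.save_state("output/at-step-1000")  # writes weights plus train_state.json
accelerator.load_state("output/at-step-1000")  # restores weights, fills steps_from_state
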
@@ -757,7 +773,53 @@ def load_model_hook(models, input_dir):
                 if key in metadata:
                     minimum_metadata[key] = metadata[key]

-        progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
+        # calculate steps to skip when resuming or starting from a specific step
+        initial_step = 0
+        if args.initial_epoch is not None or args.initial_step is not None:
+            # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
+            if steps_from_state is not None:
+                logger.warning(
+                    "steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
+                )
+            if args.initial_step is not None:
+                initial_step = args.initial_step
+            else:
+                # the number of optimizer steps per epoch depends on num_processes and gradient_accumulation_steps
+                initial_step = (args.initial_epoch - 1) * math.ceil(
+                    len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+                )
+        else:
+            # if neither initial_epoch nor initial_step is specified, steps_from_state is used when resuming
+            if steps_from_state is not None:
+                initial_step = steps_from_state
+                steps_from_state = None
+
+        if initial_step > 0:
+            assert (
+                args.max_train_steps > initial_step
+            ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
+
+        progress_bar = tqdm(
+            range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
+        )
+
+        epoch_to_start = 0
+        if initial_step > 0:
+            if args.skip_until_initial_step:
+                # if skip_until_initial_step is specified, load data and discard it up to initial_step, so the data order matches an uninterrupted run
+                if not args.resume:
+                    logger.info(
+                        "initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
+                    )
+                logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
+                initial_step *= args.gradient_accumulation_steps  # from here on, initial_step counts dataloader batches
+            else:
+                # if not, only the starting epoch number is advanced; earlier data is not replayed
+                epoch_to_start = initial_step // math.ceil(
+                    len(train_dataloader) / args.gradient_accumulation_steps
+                )
+                initial_step = 0  # do not skip
+
         global_step = 0

         noise_scheduler = DDPMScheduler(
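
To make the epoch-to-step conversion above concrete, here is the arithmetic with illustrative numbers (the dataloader length, process count, and accumulation setting are assumptions, not values from the patch):

import math

dataloader_len = 1000  # batches per epoch
num_processes = 2
grad_accum = 4
initial_epoch = 3      # start from the third epoch

steps_per_epoch = math.ceil(dataloader_len / num_processes / grad_accum)  # 125
initial_step = (initial_epoch - 1) * steps_per_epoch                      # 250 optimizer steps to skip
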
@@ -816,16 +878,29 @@ def remove_model(old_ckpt_name):
         self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

         # training loop
-        for epoch in range(num_train_epochs):
+        for skip_epoch in range(epoch_to_start):  # skip fully-covered epochs
+            logger.info(f"skipping epoch {skip_epoch + 1} because initial_step (multiplied) is {initial_step}")
+            initial_step -= len(train_dataloader)
+
+        for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch + 1}/{num_train_epochs}")
             current_epoch.value = epoch + 1

             metadata["ss_epoch"] = str(epoch + 1)

             accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)

-            for step, batch in enumerate(train_dataloader):
+            skipped_dataloader = None
+            if initial_step > 0:  # skip all but one pending batch here; the loop below discards the last one
+                skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
+                initial_step = 1
+
+            for step, batch in enumerate(skipped_dataloader or train_dataloader):
                 current_step.value = global_step
+                if initial_step > 0:  # discard the one remaining skipped batch
+                    initial_step -= 1
+                    continue
+
                 with accelerator.accumulate(training_model):
                     on_step_start(text_encoder, unet)

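
The replay path relies on Accelerate's skip_first_batches, which returns a dataloader that starts mid-epoch; together with the in-loop `continue`, exactly initial_step batches are consumed without training. A self-contained sketch of the primitive with toy data (not the trainer's dataloader):

from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()
loader = accelerator.prepare(DataLoader(list(range(8)), batch_size=1))

# begin the epoch at batch index 3; earlier batches are skipped, not trained on
for batch in accelerator.skip_first_batches(loader, 3):
    print(batch.item())  # 3, 4, 5, 6, 7
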
@@ -1126,6 +1201,25 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    parser.add_argument(
+        "--skip_until_initial_step",
+        action="store_true",
+        help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
+    )
+    parser.add_argument(
+        "--initial_epoch",
+        type=int,
+        default=None,
+        help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect the lr scheduler, which means the lr scheduler will start from 0 without `--resume`."
+        + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
+    )
+    parser.add_argument(
+        "--initial_step",
+        type=int,
+        default=None,
+        help="initial step number including all epochs, 0 means first step (same as not specifying). overrides initial_epoch."
+        + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
+    )
     # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
     # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
     # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")