
Commit 321e24d

Merge pull request kohya-ss#1353 from KohakuBlueleaf/train_resume_step

Resume correct step for "resume from state" feature.

2 parents e5bab69 + 3eb27ce

File tree

  library/train_util.py (+9 −2)
  train_network.py (+101 −7)

2 files changed: 110 additions, 9 deletions

Diff for: library/train_util.py (+9 −2)

@@ -657,8 +657,15 @@ def set_caching_mode(self, mode):
 
     def set_current_epoch(self, epoch):
         if not self.current_epoch == epoch:  # shuffle buckets when the epoch changes
-            self.shuffle_buckets()
-        self.current_epoch = epoch
+            if epoch > self.current_epoch:
+                logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
+                num_epochs = epoch - self.current_epoch
+                for _ in range(num_epochs):
+                    self.current_epoch += 1
+                    self.shuffle_buckets()
+            else:
+                logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
+                self.current_epoch = epoch
 
     def set_current_step(self, step):
         self.current_step = step
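The behavioral point of this change: when a resume jumps the dataset forward several epochs at once, each skipped epoch must still consume one bucket shuffle, otherwise the resumed run sees a different data order than an uninterrupted run would. A minimal sketch of that invariant, using a toy class with a seeded RNG (class name, buckets, and seed are illustrative, not from the repository; only the catch-up loop mirrors the diff):

    import random

    class ToyDataset:
        """Stand-in for the dataset's bucket bookkeeping."""

        def __init__(self, buckets):
            self.buckets = list(buckets)
            self.current_epoch = 0
            self.rng = random.Random(42)  # deterministic for the sketch

        def shuffle_buckets(self):
            self.rng.shuffle(self.buckets)

        def set_current_epoch(self, epoch):
            # catch-up loop as in the diff: one shuffle per skipped epoch
            if epoch > self.current_epoch:
                for _ in range(epoch - self.current_epoch):
                    self.current_epoch += 1
                    self.shuffle_buckets()

    # jumping straight to epoch 3 yields the same order as stepping 1 -> 2 -> 3
    resumed, stepped = ToyDataset("ABCD"), ToyDataset("ABCD")
    resumed.set_current_epoch(3)
    for e in (1, 2, 3):
        stepped.set_current_epoch(e)
    assert resumed.buckets == stepped.buckets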

Diff for: train_network.py (+101 −7)

@@ -493,17 +493,24 @@ def train(self, args):
         # before resuming make hook for saving/loading to save/load the network weights only
         def save_model_hook(models, weights, output_dir):
             # pop weights of other models than network to save only network weights
-            # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
-            if accelerator.is_main_process or args.deepspeed:
+            if accelerator.is_main_process:
                 remove_indices = []
                 for i, model in enumerate(models):
                     if not isinstance(model, type(accelerator.unwrap_model(network))):
                         remove_indices.append(i)
                 for i in reversed(remove_indices):
-                    if len(weights) > i:
-                        weights.pop(i)
+                    weights.pop(i)
                 # print(f"save model hook: {len(weights)} weights will be saved")
 
+            # save current epoch and step
+            train_state_file = os.path.join(output_dir, "train_state.json")
+            # +1 is needed because the state is saved before current_step is set from global_step
+            logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
+            with open(train_state_file, "w", encoding="utf-8") as f:
+                json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
+
+        steps_from_state = None
+
         def load_model_hook(models, input_dir):
             # remove models except network
             remove_indices = []
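For reference, the file this hook leaves in each state directory is a two-key JSON document. A run interrupted in its third epoch at global step 1200 would end up with something like (values illustrative, not from the diff):

    {"current_epoch": 3, "current_step": 1201}

The stored current_step is global_step + 1 because, per the comment in the diff, the hook fires before current_step is refreshed from global_step, so the recorded value is the first step the resumed run should execute.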
@@ -514,6 +521,15 @@ def load_model_hook(models, input_dir):
                     models.pop(i)
             # print(f"load model hook: {len(models)} models will be loaded")
 
+            # load current epoch and step
+            nonlocal steps_from_state
+            train_state_file = os.path.join(input_dir, "train_state.json")
+            if os.path.exists(train_state_file):
+                with open(train_state_file, "r", encoding="utf-8") as f:
+                    data = json.load(f)
+                steps_from_state = data["current_step"]
+                logger.info(f"load train state from {train_state_file}: {data}")
+
         accelerator.register_save_state_pre_hook(save_model_hook)
         accelerator.register_load_state_pre_hook(load_model_hook)
 
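Both hooks ride on Accelerate's checkpoint hook API: save-state pre-hooks receive (models, weights, output_dir) and run inside accelerator.save_state() before anything hits disk; load-state pre-hooks receive (models, input_dir) and run inside accelerator.load_state() before weights are restored. A self-contained sketch of that wiring (the print bodies are placeholders, not the trainer's logic):

    from accelerate import Accelerator

    accelerator = Accelerator()

    def save_hook(models, weights, output_dir):
        # runs first inside accelerator.save_state(output_dir); mutating
        # `weights` here controls what gets written, as the diff does above
        print(f"about to save {len(weights)} weight sets to {output_dir}")

    def load_hook(models, input_dir):
        # runs first inside accelerator.load_state(input_dir); a side channel
        # like the `nonlocal steps_from_state` above is how a hook hands
        # information back to the trainer
        print(f"about to load {len(models)} models from {input_dir}")

    accelerator.register_save_state_pre_hook(save_hook)
    accelerator.register_load_state_pre_hook(load_hook)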
@@ -757,7 +773,53 @@ def load_model_hook(models, input_dir):
             if key in metadata:
                 minimum_metadata[key] = metadata[key]
 
-        progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
+        # calculate steps to skip when resuming or starting from a specific step
+        initial_step = 0
+        if args.initial_epoch is not None or args.initial_step is not None:
+            # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
+            if steps_from_state is not None:
+                logger.warning(
+                    "steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
+                )
+            if args.initial_step is not None:
+                initial_step = args.initial_step
+            else:
+                # num steps per epoch is calculated by num_processes and gradient_accumulation_steps
+                initial_step = (args.initial_epoch - 1) * math.ceil(
+                    len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+                )
+        else:
+            # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
+            if steps_from_state is not None:
+                initial_step = steps_from_state
+            steps_from_state = None
+
+        if initial_step > 0:
+            assert (
+                args.max_train_steps > initial_step
+            ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
+
+        progress_bar = tqdm(
+            range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
+        )
+
+        epoch_to_start = 0
+        if initial_step > 0:
+            if args.skip_until_initial_step:
+                # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
+                if not args.resume:
+                    logger.info(
+                        f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
+                    )
+                logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
+                initial_step *= args.gradient_accumulation_steps
+            else:
+                # if not, only the epoch number is skipped, for informative purposes
+                epoch_to_start = initial_step // math.ceil(
+                    len(train_dataloader) / args.gradient_accumulation_steps
+                )
+                initial_step = 0  # do not skip
+
         global_step = 0
 
         noise_scheduler = DDPMScheduler(
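To make the epoch-to-step conversion concrete, here is the arithmetic with made-up numbers (dataloader length, process count, and accumulation steps are all hypothetical):

    import math

    len_train_dataloader = 1000        # batches seen by one process per epoch
    num_processes = 2
    gradient_accumulation_steps = 4

    # optimizer steps per epoch, exactly as computed in the diff
    steps_per_epoch = math.ceil(len_train_dataloader / num_processes / gradient_accumulation_steps)
    # --initial_epoch 3 means two full epochs are already behind us
    initial_step = (3 - 1) * steps_per_epoch
    print(steps_per_epoch, initial_step)  # 125 250

Note that --initial_step, when given, bypasses this computation entirely, and the assert guarantees at least one training step remains below max_train_steps.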
@@ -816,16 +878,29 @@ def remove_model(old_ckpt_name):
         self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
 
         # training loop
-        for epoch in range(num_train_epochs):
+        for skip_epoch in range(epoch_to_start):  # skip epochs
+            logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
+            initial_step -= len(train_dataloader)
+
+        for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
             current_epoch.value = epoch + 1
 
             metadata["ss_epoch"] = str(epoch + 1)
 
             accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
 
-            for step, batch in enumerate(train_dataloader):
+            skipped_dataloader = None
+            if initial_step > 0:
+                skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
+                initial_step = 1
+
+            for step, batch in enumerate(skipped_dataloader or train_dataloader):
                 current_step.value = global_step
+                if initial_step > 0:
+                    initial_step -= 1
+                    continue
+
                 with accelerator.accumulate(training_model):
                     on_step_start(text_encoder, unet)
 
@@ -1126,6 +1201,25 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    parser.add_argument(
+        "--skip_until_initial_step",
+        action="store_true",
+        help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
+    )
+    parser.add_argument(
+        "--initial_epoch",
+        type=int,
+        default=None,
+        help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect the lr scheduler, which means the lr scheduler will start from 0 without `--resume`."
+        + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
+    )
+    parser.add_argument(
+        "--initial_step",
+        type=int,
+        default=None,
+        help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
+        + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
+    )
     # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
     # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
     # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
