
Commit 22413a5

Merge pull request #1359 from kohya-ss/train_resume_step
Train resume step
2 parents 3259928 + 18d7597 commit 22413a5

File tree: 3 files changed (+126, -5 lines)


README.md (+12)
@@ -178,6 +178,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 
 - The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds!
 
+- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf! (See the worked example below.)
+  - This resolves the issue where the order of data loading from DataSet changes when resuming training.
+  - Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before).
+  - If `--resume` is specified, the step saved in the state is used.
+  - Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`).
+
 - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
   - It seems that the model file loading is faster in the WSL environment etc.
   - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
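A short worked example of how `--initial_epoch` maps to a starting step under the arithmetic used in `train_network.py` below; the dataloader length, process count, and gradient-accumulation value are illustrative assumptions, not values from the PR.

```python
import math

# Illustrative assumptions (not from the PR): 1000 batches per epoch,
# a single process, and gradient accumulation of 4.
len_train_dataloader = 1000
num_processes = 1
gradient_accumulation_steps = 4

# --initial_step takes precedence over --initial_epoch, and both take
# precedence over the step restored from the saved state on --resume.
initial_epoch = 3  # e.g. --initial_epoch 3 --skip_until_initial_step
steps_per_epoch = math.ceil(len_train_dataloader / num_processes / gradient_accumulation_steps)
initial_step = (initial_epoch - 1) * steps_per_epoch
print(initial_step)  # 500 optimizer steps (two full epochs) are skipped
```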
@@ -235,6 +241,12 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
 
 - The ControlNet training script `train_controlnet.py` for SD1.5/2.x had stopped working; this has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds!
 
+- In `train_network.py` and `sdxl_train_network.py`, the DataSet loading order is now also restored when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf!
+  - This resolves the issue where the DataSet loading order changed when resuming training.
+  - Specifying the `--skip_until_initial_step` option skips DataSet loading up to the specified step. If not specified, the behavior is unchanged (loading starts from the beginning of the DataSet).
+  - If the `--resume` option is specified, the step count saved in the state is used.
+  - Specifying the `--initial_step` or `--initial_epoch` option skips DataSet loading up to the specified step or epoch. Use these options together with `--skip_until_initial_step`. They can also be used without `--resume` (for example, when resuming training with `--network_weights`).
+
 - An option `--disable_mmap_load_safetensors` has been added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
   - Model file loading appears to be faster in WSL environments and the like.
   - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.

library/train_util.py (+12, -2)
@@ -657,8 +657,16 @@ def set_caching_mode(self, mode):
 
     def set_current_epoch(self, epoch):
         if not self.current_epoch == epoch:  # shuffle buckets when the epoch changes
-            self.shuffle_buckets()
-        self.current_epoch = epoch
+            if epoch > self.current_epoch:
+                logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
+                num_epochs = epoch - self.current_epoch
+                for _ in range(num_epochs):
+                    self.current_epoch += 1
+                    self.shuffle_buckets()
+                # self.current_epoch seems to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
+            else:
+                logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
+                self.current_epoch = epoch
 
     def set_current_step(self, step):
         self.current_step = step
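When several epochs are skipped at once, the method above now replays `shuffle_buckets()` once per elapsed epoch rather than shuffling a single time. A minimal standalone sketch of why that matters (the toy shuffle below is an assumption for illustration, not the library's bucket logic): a shuffle whose RNG state carries over between epochs only reproduces epoch N if it is replayed N times.

```python
import random

def bucket_order_after(n_epochs, seed=42):
    # one in-place shuffle per epoch; the RNG state carries over between epochs
    rng = random.Random(seed)
    buckets = list(range(8))
    for _ in range(n_epochs):
        rng.shuffle(buckets)
    return buckets

# Replaying three shuffles reproduces the order a three-epoch run would see,
# while a single shuffle (jumping straight to epoch 3) generally does not.
print(bucket_order_after(3))
print(bucket_order_after(1))
```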
@@ -5553,6 +5561,8 @@ def add(self, *, epoch: int, step: int, loss: float) -> None:
         if epoch == 0:
             self.loss_list.append(loss)
         else:
+            while len(self.loss_list) <= step:
+                self.loss_list.append(0.0)
             self.loss_total -= self.loss_list[step]
             self.loss_list[step] = loss
         self.loss_total += loss
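The two added lines pad `loss_list` with zeros so that indexing by `step` stays valid when the first recorded step of a resumed run lands partway through an epoch. A minimal sketch of the failure mode being guarded against, using a stripped-down recorder (not the library class):

```python
class TinyLossRecorder:
    """Stripped-down analogue of a per-step loss list."""

    def __init__(self):
        self.loss_list = []
        self.loss_total = 0.0

    def add(self, *, epoch, step, loss):
        if epoch == 0:
            self.loss_list.append(loss)
        else:
            # without this padding, a resumed run whose first call arrives at
            # step >= len(loss_list) would raise IndexError below
            while len(self.loss_list) <= step:
                self.loss_list.append(0.0)
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

rec = TinyLossRecorder()
rec.add(epoch=1, step=5, loss=0.3)  # first call after skipping ahead
print(len(rec.loss_list), rec.loss_total)  # 6 0.3
```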

train_network.py (+102, -3)
@@ -504,6 +504,15 @@ def save_model_hook(models, weights, output_dir):
                 weights.pop(i)
             # print(f"save model hook: {len(weights)} weights will be saved")
 
+            # save current epoch and step
+            train_state_file = os.path.join(output_dir, "train_state.json")
+            # +1 is needed because the state is saved before current_step is set from global_step
+            logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
+            with open(train_state_file, "w", encoding="utf-8") as f:
+                json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
+
+        steps_from_state = None
+
         def load_model_hook(models, input_dir):
             # remove models except network
             remove_indices = []
@@ -514,6 +523,15 @@ def load_model_hook(models, input_dir):
                 models.pop(i)
             # print(f"load model hook: {len(models)} models will be loaded")
 
+            # load current epoch and step
+            nonlocal steps_from_state
+            train_state_file = os.path.join(input_dir, "train_state.json")
+            if os.path.exists(train_state_file):
+                with open(train_state_file, "r", encoding="utf-8") as f:
+                    data = json.load(f)
+                steps_from_state = data["current_step"]
+                logger.info(f"load train state from {train_state_file}: {data}")
+
         accelerator.register_save_state_pre_hook(save_model_hook)
         accelerator.register_load_state_pre_hook(load_model_hook)
 
@@ -757,7 +775,54 @@ def load_model_hook(models, input_dir):
             if key in metadata:
                 minimum_metadata[key] = metadata[key]
 
-        progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
+        # calculate steps to skip when resuming or starting from a specific step
+        initial_step = 0
+        if args.initial_epoch is not None or args.initial_step is not None:
+            # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
+            if steps_from_state is not None:
+                logger.warning(
+                    "steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
+                )
+            if args.initial_step is not None:
+                initial_step = args.initial_step
+            else:
+                # num steps per epoch is calculated by num_processes and gradient_accumulation_steps
+                initial_step = (args.initial_epoch - 1) * math.ceil(
+                    len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+                )
+        else:
+            # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
+            if steps_from_state is not None:
+                initial_step = steps_from_state
+            steps_from_state = None
+
+        if initial_step > 0:
+            assert (
+                args.max_train_steps > initial_step
+            ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
+
+        progress_bar = tqdm(
+            range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
+        )
+
+        epoch_to_start = 0
+        if initial_step > 0:
+            if args.skip_until_initial_step:
+                # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
+                if not args.resume:
+                    logger.info(
+                        f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
+                    )
+                logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
+                initial_step *= args.gradient_accumulation_steps
+
+                # set epoch to start to make initial_step less than len(train_dataloader)
+                epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+            else:
+                # if not, only epoch no is skipped for informative purpose
+                epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+                initial_step = 0  # do not skip
+
         global_step = 0
 
         noise_scheduler = DDPMScheduler(
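A worked example of the bookkeeping above, with illustrative numbers (gradient accumulation of 1 so optimizer steps and batches coincide; the dataloader length is an assumption): resuming at step 1501 with 600 batches per epoch means two whole epochs are skipped and 301 batches are discarded from the next epoch before training continues.

```python
import math

len_train_dataloader = 600          # batches per epoch (illustrative)
gradient_accumulation_steps = 1
initial_step = 1501                 # e.g. restored from train_state.json

initial_step *= gradient_accumulation_steps  # still 1501 with accumulation of 1
epoch_to_start = initial_step // math.ceil(len_train_dataloader / gradient_accumulation_steps)
remaining_batches = initial_step - epoch_to_start * len_train_dataloader
print(epoch_to_start, remaining_batches)  # 2 301
```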
@@ -816,16 +881,31 @@ def remove_model(old_ckpt_name):
             self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
 
         # training loop
-        for epoch in range(num_train_epochs):
+        if initial_step > 0:  # only if skip_until_initial_step is specified
+            for skip_epoch in range(epoch_to_start):  # skip epochs
+                logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
+                initial_step -= len(train_dataloader)
+            global_step = initial_step
+
+        for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
             current_epoch.value = epoch + 1
 
             metadata["ss_epoch"] = str(epoch + 1)
 
             accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
 
-            for step, batch in enumerate(train_dataloader):
+            skipped_dataloader = None
+            if initial_step > 0:
+                skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
+                initial_step = 1
+
+            for step, batch in enumerate(skipped_dataloader or train_dataloader):
                 current_step.value = global_step
+                if initial_step > 0:
+                    initial_step -= 1
+                    continue
+
                 with accelerator.accumulate(training_model):
                     on_step_start(text_encoder, unet)
 
@@ -1126,6 +1206,25 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    parser.add_argument(
+        "--skip_until_initial_step",
+        action="store_true",
+        help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
+    )
+    parser.add_argument(
+        "--initial_epoch",
+        type=int,
+        default=None,
+        help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
+        + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
+    )
+    parser.add_argument(
+        "--initial_step",
+        type=int,
+        default=None,
+        help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
+        + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
+    )
     # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
     # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
     # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
