
Commit 56bb81c

add grad_hook after restore state closes kohya-ss#1344
1 parent 22413a5 commit 56bb81c

File tree

1 file changed: +25, -21 lines


sdxl_train.py (+25, -21)
@@ -481,6 +481,26 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder2 = accelerator.prepare(text_encoder2)
         optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

+    # when caching the Text Encoder outputs, move the Text Encoders to CPU
+    if args.cache_text_encoder_outputs:
+        # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
+        text_encoder1.to("cpu", dtype=torch.float32)
+        text_encoder2.to("cpu", dtype=torch.float32)
+        clean_memory_on_device(accelerator.device)
+    else:
+        # make sure Text Encoders are on GPU
+        text_encoder1.to(accelerator.device)
+        text_encoder2.to(accelerator.device)
+
+    # experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
+    if args.full_fp16:
+        # During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
+        # -> But we think it's OK to patch the accelerator even if DeepSpeed is enabled.
+        train_util.patch_accelerator_for_fp16_training(accelerator)
+
+    # resume training state if specified
+    train_util.resume_from_local_or_hf_if_specified(accelerator, args)
+
     if args.fused_backward_pass:
         # use fused optimizer for backward pass: other optimizers will be supported in the future
         import library.adafactor_fused
@@ -532,26 +552,6 @@ def optimizer_hook(parameter: torch.Tensor):
                         parameter_optimizer_map[parameter] = opt_idx
                         num_parameters_per_group[opt_idx] += 1

-    # when caching the Text Encoder outputs, move the Text Encoders to CPU
-    if args.cache_text_encoder_outputs:
-        # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
-        text_encoder1.to("cpu", dtype=torch.float32)
-        text_encoder2.to("cpu", dtype=torch.float32)
-        clean_memory_on_device(accelerator.device)
-    else:
-        # make sure Text Encoders are on GPU
-        text_encoder1.to(accelerator.device)
-        text_encoder2.to(accelerator.device)
-
-    # experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
-    if args.full_fp16:
-        # During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
-        # -> But we think it's OK to patch the accelerator even if DeepSpeed is enabled.
-        train_util.patch_accelerator_for_fp16_training(accelerator)
-
-    # resume training state if specified
-    train_util.resume_from_local_or_hf_if_specified(accelerator, args)
-
     # calculate the number of epochs
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -589,7 +589,11 @@ def optimizer_hook(parameter: torch.Tensor):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
+        accelerator.init_trackers(
+            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
+            config=train_util.get_sanitized_config_or_none(args),
+            init_kwargs=init_kwargs,
+        )

     # For --sample_at_first
     sdxl_train_util.sample_images(
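
Read together, the hunks move train_util.resume_from_local_or_hf_if_specified (along with the Text Encoder placement and the full_fp16 patch) so that it runs before the --fused_backward_pass / --fused_optimizer_groups code that registers per-parameter gradient hooks, which is what the commit title "add grad_hook after restore state" describes. Below is a minimal sketch of that ordering in plain PyTorch; it is not the repository's code (the model, the commented-out checkpoint path, and the per-parameter SGD optimizers are illustrative assumptions), but torch.Tensor.register_post_accumulate_grad_hook (PyTorch 2.1+) is the standard API for this kind of hook.

import torch
import torch.nn as nn

model = nn.Linear(16, 4)

# 1) Restore training state first. (Hypothetical checkpoint path; sdxl_train.py does this
#    via train_util.resume_from_local_or_hf_if_specified(accelerator, args).)
# state = torch.load("checkpoint.pt")
# model.load_state_dict(state["model"])

# 2) Only then register the grad hooks: here, one small optimizer per parameter,
#    stepped as soon as that parameter's gradient has been accumulated.
optimizers = {p: torch.optim.SGD([p], lr=1e-3) for p in model.parameters()}

def optimizer_hook(param: torch.Tensor) -> None:
    optimizers[param].step()
    optimizers[param].zero_grad(set_to_none=True)  # drop the gradient right away

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optimizer_hook)

# 3) A training step: each parameter is updated during backward(), as soon as its
#    gradient is ready, so no full set of gradients is kept around.
loss = model(torch.randn(2, 16)).sum()
loss.backward()

In sdxl_train.py the hook body drives the fused Adafactor path (library.adafactor_fused) or the grouped optimizers rather than SGD; the sketch only illustrates the ordering the commit enforces, with the hooks attached after the state restore.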
