@@ -481,6 +481,26 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder2 = accelerator.prepare(text_encoder2)
         optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
+    # move the Text Encoders to CPU when caching their outputs
+    if args.cache_text_encoder_outputs:
+        # move Text Encoders to CPU in fp32 so they can still be used for sampling images (the Text Encoder doesn't work on CPU with fp16)
+        text_encoder1.to("cpu", dtype=torch.float32)
+        text_encoder2.to("cpu", dtype=torch.float32)
+        clean_memory_on_device(accelerator.device)
+    else:
+        # make sure Text Encoders are on GPU
+        text_encoder1.to(accelerator.device)
+        text_encoder2.to(accelerator.device)
+
+    # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
+    if args.full_fp16:
+        # During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
+        # -> But we think it's ok to patch the accelerator even if DeepSpeed is enabled.
+        train_util.patch_accelerator_for_fp16_training(accelerator)
+
+    # resume training if specified
+    train_util.resume_from_local_or_hf_if_specified(accelerator, args)
+
     if args.fused_backward_pass:
         # use fused optimizer for backward pass: other optimizers will be supported in the future
         import library.adafactor_fused
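
Taken on its own, the block this hunk moves up implements a simple placement rule for the two SDXL Text Encoders. Below is a minimal standalone sketch of that rule, assuming a CUDA device; `release_device_memory` is a hypothetical stand-in for the library's `clean_memory_on_device` helper, and `place_text_encoders` is an illustrative wrapper, not a function in the repository.

# Illustrative sketch only: mirrors the placement logic the hunk relocates.
import gc
import torch

def release_device_memory(device: torch.device) -> None:
    # Hypothetical stand-in for the library's clean_memory_on_device helper.
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

def place_text_encoders(
    text_encoder1: torch.nn.Module,
    text_encoder2: torch.nn.Module,
    device: torch.device,
    cache_text_encoder_outputs: bool,
) -> None:
    if cache_text_encoder_outputs:
        # Outputs are precomputed, so the encoders are only needed for sampling
        # images; keep them on CPU in fp32 (fp16 doesn't run on CPU) and free VRAM.
        text_encoder1.to("cpu", dtype=torch.float32)
        text_encoder2.to("cpu", dtype=torch.float32)
        release_device_memory(device)
    else:
        # The encoders run every training step, so they must stay on the device.
        text_encoder1.to(device)
        text_encoder2.to(device)
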
@@ -532,26 +552,6 @@ def optimizer_hook(parameter: torch.Tensor):
                         parameter_optimizer_map[parameter] = opt_idx
                         num_parameters_per_group[opt_idx] += 1
 
-    # move the Text Encoders to CPU when caching their outputs
-    if args.cache_text_encoder_outputs:
-        # move Text Encoders to CPU in fp32 so they can still be used for sampling images (the Text Encoder doesn't work on CPU with fp16)
-        text_encoder1.to("cpu", dtype=torch.float32)
-        text_encoder2.to("cpu", dtype=torch.float32)
-        clean_memory_on_device(accelerator.device)
-    else:
-        # make sure Text Encoders are on GPU
-        text_encoder1.to(accelerator.device)
-        text_encoder2.to(accelerator.device)
-
-    # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
-    if args.full_fp16:
-        # During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
-        # -> But we think it's ok to patch the accelerator even if DeepSpeed is enabled.
-        train_util.patch_accelerator_for_fp16_training(accelerator)
-
-    # resume training if specified
-    train_util.resume_from_local_or_hf_if_specified(accelerator, args)
-
     # calculate the number of epochs
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
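
For context on the relocated `full_fp16` branch: `patch_accelerator_for_fp16_training` exists because PyTorch's `GradScaler` normally refuses to unscale gradients that are themselves fp16. A rough sketch of that idea follows; it assumes accelerate exposes the scaler as `accelerator.scaler` and relies on private `GradScaler` internals, so treat it as an illustration of the intent rather than the repository's actual implementation.

import torch

def allow_fp16_grad_unscaling(accelerator) -> None:
    # Wrap GradScaler's private unscale step so it is always invoked with
    # allow_fp16=True, letting loss scaling work even when gradients are
    # stored in fp16. Attribute names are assumptions about PyTorch/accelerate
    # internals, not a copy of the repository's implementation.
    scaler = accelerator.scaler
    original_unscale = scaler._unscale_grads_

    def unscale_allowing_fp16(optimizer, inv_scale, found_inf, allow_fp16):
        return original_unscale(optimizer, inv_scale, found_inf, True)

    scaler._unscale_grads_ = unscale_allowing_fp16
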
@@ -589,7 +589,11 @@ def optimizer_hook(parameter: torch.Tensor):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
+        accelerator.init_trackers(
+            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
+            config=train_util.get_sanitized_config_or_none(args),
+            init_kwargs=init_kwargs,
+        )
 
     # For --sample_at_first
     sdxl_train_util.sample_images(