From 8fc30f820595f80ec3f09738cc4cf01f441c41b7 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Mon, 21 Oct 2024 07:34:33 -0400
Subject: [PATCH 1/3] Fix training for V-pred and ztSNR

1) Updates debiased estimation loss function for V-pred.
2) Prevents now-deprecated scaling of loss if ztSNR is enabled.
---
 fine_tune.py                         | 4 ++--
 library/custom_train_functions.py    | 7 +++++--
 library/train_util.py                | 5 +++++
 sdxl_train.py                        | 4 ++--
 sdxl_train_control_net_lllite.py     | 4 ++--
 sdxl_train_control_net_lllite_old.py | 4 ++--
 train_db.py                          | 4 ++--
 train_network.py                     | 4 ++--
 train_textual_inversion.py           | 4 ++--
 train_textual_inversion_XTI.py       | 4 ++--
 10 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index b556672d2..19a35229f 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -383,10 +383,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # mean over batch dimension
             else:
diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index 2a513dc5b..faf443048 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -96,10 +96,13 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
     return loss


-def apply_debiased_estimation(loss, timesteps, noise_scheduler):
+def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
     snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
     snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
-    weight = 1 / torch.sqrt(snr_t)
+    if v_prediction:
+        weight = 1 / (snr_t + 1)
+    else:
+        weight = 1 / torch.sqrt(snr_t)
     loss = weight * loss
     return loss

diff --git a/library/train_util.py b/library/train_util.py
index 27910dc90..adb983d2f 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3731,6 +3731,11 @@ def verify_training_args(args: argparse.Namespace):
         raise ValueError(
             "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
         )
+
+    if args.scale_v_pred_loss_like_noise_pred and args.zero_terminal_snr:
+        raise ValueError(
+            "zero_terminal_snr enabled. scale_v_pred_loss_like_noise_pred will not be used / zero_terminal_snrが有効です。scale_v_pred_loss_like_noise_predは使用されません"
+        )

     if args.v_pred_like_loss and args.v_parameterization:
         raise ValueError(
diff --git a/sdxl_train.py b/sdxl_train.py
index e0a8f2b2e..44ee9233f 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -725,12 +725,12 @@ def optimizer_hook(parameter: torch.Tensor):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # mean over batch dimension
             else:
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 5ff060a9f..436f0e194 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -474,12 +474,12 @@ def remove_model(old_ckpt_name):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 292a0463a..8fba9eba6 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -434,12 +434,12 @@ def remove_model(old_ckpt_name):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
diff --git a/train_db.py b/train_db.py
index 2c7f02582..d5a94a565 100644
--- a/train_db.py
+++ b/train_db.py
@@ -370,10 +370,10 @@ def train(args):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
diff --git a/train_network.py b/train_network.py
index 044ec3aa8..790fbfc9d 100644
--- a/train_network.py
+++ b/train_network.py
@@ -993,12 +993,12 @@ def remove_model(old_ckpt_name):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 96e7bd509..10b34db5e 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -598,12 +598,12 @@ def remove_model(old_ckpt_name):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index efb59137b..084b90c60 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -483,10 +483,10 @@ def remove_model(old_ckpt_name):
                 loss = loss * loss_weights
                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

From e1b63c2249345e4f14c10cbb252da68157ac13b7 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Mon, 21 Oct 2024 08:12:53 -0400
Subject: [PATCH 2/3] Only add warning for deprecated scaling vpred loss function

---
 fine_tune.py                         |  2 +-
 library/train_util.py                | 11 ++++++-----
 sdxl_train.py                        |  2 +-
 sdxl_train_control_net_lllite.py     |  2 +-
 sdxl_train_control_net_lllite_old.py |  2 +-
 train_db.py                          |  2 +-
 train_network.py                     |  2 +-
 train_textual_inversion.py           |  2 +-
 train_textual_inversion_XTI.py       |  2 +-
 9 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 19a35229f..c79f97d25 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -383,7 +383,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
                     loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
diff --git a/library/train_util.py b/library/train_util.py
index adb983d2f..f479dcc64 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3727,15 +3727,16 @@ def verify_training_args(args: argparse.Namespace):
     if args.adaptive_noise_scale is not None and args.noise_offset is None:
         raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")

+    if args.scale_v_pred_loss_like_noise_pred:
+        logger.warning(
+            f"scale_v_pred_loss_like_noise_pred is deprecated. it is suggested to use min_snr_gamma or debiased_estimation_loss"
+            + " / scale_v_pred_loss_like_noise_pred は非推奨です。min_snr_gammaまたはdebiased_estimation_lossを使用することをお勧めします"
+        )
+
     if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization:
         raise ValueError(
             "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
         )
-
-    if args.scale_v_pred_loss_like_noise_pred and args.zero_terminal_snr:
-        raise ValueError(
-            "zero_terminal_snr enabled. scale_v_pred_loss_like_noise_pred will not be used / zero_terminal_snrが有効です。scale_v_pred_loss_like_noise_predは使用されません"
-        )

     if args.v_pred_like_loss and args.v_parameterization:
         raise ValueError(
diff --git a/sdxl_train.py b/sdxl_train.py
index 44ee9233f..b533b2749 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -725,7 +725,7 @@ def optimizer_hook(parameter: torch.Tensor):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 436f0e194..0e67cde5c 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -474,7 +474,7 @@ def remove_model(old_ckpt_name):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 8fba9eba6..4a01f9e2c 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -434,7 +434,7 @@ def remove_model(old_ckpt_name):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
diff --git a/train_db.py b/train_db.py
index d5a94a565..e7cf3cde3 100644
--- a/train_db.py
+++ b/train_db.py
@@ -370,7 +370,7 @@ def train(args):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
                     loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
diff --git a/train_network.py b/train_network.py
index 790fbfc9d..7bf125dca 100644
--- a/train_network.py
+++ b/train_network.py
@@ -993,7 +993,7 @@ def remove_model(old_ckpt_name):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 10b34db5e..37349da7d 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -598,7 +598,7 @@ def remove_model(old_ckpt_name):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 084b90c60..fac0787b9 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -483,7 +483,7 @@ def remove_model(old_ckpt_name):
                 loss = loss * loss_weights
                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
                     loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

From 0e7c5929336173e30d7932c0706eaf61a7d396f4 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Tue, 22 Oct 2024 11:19:34 -0400
Subject: [PATCH 3/3] Remove scale_v_pred_loss_like_noise_pred deprecation

https://github.com/kohya-ss/sd-scripts/pull/1715#issuecomment-2427876376
---
 library/train_util.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index f479dcc64..27910dc90 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3727,12 +3727,6 @@ def verify_training_args(args: argparse.Namespace):
     if args.adaptive_noise_scale is not None and args.noise_offset is None:
         raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")

-    if args.scale_v_pred_loss_like_noise_pred:
-        logger.warning(
-            f"scale_v_pred_loss_like_noise_pred is deprecated. it is suggested to use min_snr_gamma or debiased_estimation_loss"
-            + " / scale_v_pred_loss_like_noise_pred は非推奨です。min_snr_gammaまたはdebiased_estimation_lossを使用することをお勧めします"
-        )
-
     if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization:
         raise ValueError(
             "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"