From 1e8108fec9962333e4cf2a8db1dcedf657049900 Mon Sep 17 00:00:00 2001 From: liesen Date: Sat, 24 Aug 2024 01:38:17 +0300 Subject: [PATCH] Handle args.v_parameterization properly for MinSNR and changed prediction target --- sdxl_train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860be..14b259657 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -590,7 +590,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with accelerator.autocast(): noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - target = noise + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise if ( args.min_snr_gamma @@ -606,7 +610,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: