 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
 import copy
-import numpy as np

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -673,13 +672,9 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)


-
         restart_list = {}
-        all_sigmas = torch.sqrt((1 - self.scheduler.alphas_cumprod) / self.scheduler.alphas_cumprod)
-        sigma_max = all_sigmas.max()
-
         if restart:
-            # all_sigmas = torch.sqrt((1 - self.scheduler.alphas_cumprod) / self.scheduler.alphas_cumprod)
+            all_sigmas = torch.sqrt((1 - self.scheduler.alphas_cumprod) / self.scheduler.alphas_cumprod)
             sigmas = all_sigmas[timesteps.cpu().numpy()]

             # {t_min: [N_restart, K, t_max], ... }
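
Note (reviewer sketch, not part of the diff): the `all_sigmas` line moved into the `if restart:` branch converts the scheduler's cumulative alphas into noise levels via sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), then indexes them by the sampled timesteps. A minimal standalone illustration under an assumed toy beta schedule (the schedule values and timesteps below are hypothetical; the real pipeline reads alphas_cumprod from self.scheduler):

    import torch

    # toy linear beta schedule, only for illustration
    betas = torch.linspace(1e-4, 2e-2, 1000)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t)
    all_sigmas = torch.sqrt((1 - alphas_cumprod) / alphas_cumprod)

    timesteps = torch.tensor([999, 749, 499, 249, 0])  # hypothetical sampled timesteps
    sigmas = all_sigmas[timesteps]                      # per-step noise levels, high to low
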
@@ -704,11 +699,6 @@ def __call__(
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         image_list = []
-        pre_noise_pred_text = None
-        pre_pred_original_sample = None
-        pre_2_noise_pred_text = None
-        history_noise_pred_text = torch.zeros((len(timesteps[:-1]), *latents.shape)).to(latents.device)
-        pre_latent = None
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps[:-1]):
                 # expand the latents if we are doing classifier free guidance
@@ -726,77 +716,10 @@ def __call__(
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    #noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                    guide_idx = 5
-
-                    if i > guide_idx:
-                        # In the first 10 steps, we still follow the conventional cfg (I will explain the reason)
-                        pre_noise_pred_text = history_noise_pred_text[i - 1]
-                        # new_d = pre_noise_pred_text \
-                        #     * noise_pred_text.norm(p=2, dim=[1,2,3], keepdim=True)\
-                        #     / pre_noise_pred_text.norm(p=2, dim=[1,2,3], keepdim=True)
-                        # if i < 10:
-                        #     new_d = pre_noise_pred_text - noise_pred_text
-                        # else:
-                        #     new_d = noise_pred_text - pre_noise_pred_text
-
-                        pre_noise_pred_text = pre_noise_pred_text * noise_pred_text.norm(p=2, dim=[1, 2, 3],
-                                                                                         keepdim=True).mean() / pre_noise_pred_text.norm(p=2,
-                                                                                                                                          dim=[1,2,3],
-                                                                                                                                          keepdim=True).mean()
-                        new_d = noise_pred_text - pre_noise_pred_text
-
-                        # # new_off_set = noise_pred_text - new_d
-                        # old_off_set = noise_pred_text - noise_pred_uncond
-                        # #
-                        # noise_pred_text_vec = noise_pred_text.view(len(noise_pred_text), -1)
-                        # pre_noise_pred_text_vec = pre_noise_pred_text.view(len(pre_noise_pred_text), -1)
-                        # pre_noise_pred_text_vec_2 = history_noise_pred_text[i-2].view(len(pre_noise_pred_text), -1)
-                        # noise_pred_uncond_vec = noise_pred_uncond.view(len(noise_pred_uncond), -1)
-                        #
-                        # for num in range(len(noise_pred_text)):
-                        #     A = torch.stack([noise_pred_text_vec[num],
-                        #                      pre_noise_pred_text_vec[num]]).transpose(0, 1)
-                        #     B = (noise_pred_text_vec[num] - noise_pred_uncond_vec[num]).view(-1, 1)
-                        #
-                        #     X = torch.linalg.lstsq(A, B).solution
-                        #print(f"num:{num}, X:{X}")
-                        #     new_d[num] = noise_pred_text[num] * X[0] + \
-                        #                  pre_noise_pred_text[num] * X[1]
-
-
-                        # compute cosine similarity
-                        # orthogonal = old_off_set - new_d
-                        # orthogonal_nor = orthogonal / orthogonal.norm(p=2, dim=[1,2,3], keepdim=True)
-                        # new_off_set_nor = new_d / new_d.norm(p=2, dim=[1,2,3], keepdim=True)
-                        # old_off_set_nor = old_off_set / old_off_set.norm(p=2, dim=[1, 2, 3], keepdim=True)
-                        #
-                        # print(i, sigma, (new_off_set_nor * old_off_set_nor).sum(dim=[1,2,3]),
-                        #       (new_off_set_nor * orthogonal_nor).sum(dim=[1,2,3]))
-
-                    # pre_pred_original_sample = pred_original_sample
-                    # pre_latent = latents
-
-                    #noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)
-                    sigma = all_sigmas[t.cpu().numpy()]
-                    sigma = sigma.to(noise_pred_text.device)
-                    print(i, sigma)
-
-                    if i > guide_idx:
-                        scaling = sigma / 6
-                        # scaling = 1
-                        # scaling = 0
-                        noise_pred = noise_pred_text + scaling * guidance_scale * (new_d)
-                    else:
-                        noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                    history_noise_pred_text[i] = noise_pred_text
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
-                temp = 1 if i > guide_idx else 1
-                sde = False if i > guide_idx else False
-                latents_next = self.scheduler.step(noise_pred, i, latents, temp=temp, sde=sde, **extra_step_kwargs).prev_sample
+                latents_next = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample

                 # Apply 2nd order correction (Heun).
                 if second_order and i < len(timesteps) - 2:
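
Note (reviewer sketch, not part of the diff): the restored `+` line is standard classifier-free guidance, epsilon = epsilon_uncond + w * (epsilon_text - epsilon_uncond), replacing the removed history-based guidance experiment. A self-contained sketch with toy tensors (the shapes and guidance scale below are illustrative):

    import torch

    guidance_scale = 7.5                             # w; w > 1 amplifies the text-conditioned direction
    noise_pred_uncond = torch.randn(1, 4, 64, 64)    # toy unconditional epsilon prediction
    noise_pred_text = torch.randn(1, 4, 64, 64)      # toy text-conditioned epsilon prediction
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
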
@@ -844,7 +767,7 @@ def __call__(
                     print("restart steps:", new_t_steps)

                     latents = self.scheduler.add_noise_between_t(latents, new_t_steps[-1], new_t_steps[0], generator, S_noise=S_noise)
-                    history_noise_pred_text_restart = torch.zeros((len(new_t_steps[:-1]), *latents.shape)).to(latents.device)
+
                     for j, new_t in enumerate(new_t_steps[:-1]):
                         # print(" restart:", new_t)
                         latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -859,35 +782,12 @@ def __call__(
                         ).sample

                         # perform guidance
-                        guide_idx_restart = 0
                         if do_classifier_free_guidance:
-
                             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                            if j > guide_idx_restart:
-                                # In the first 10 steps, we still follow the conventional cfg (I will explain the reason)
-                                pre_noise_pred_text = history_noise_pred_text_restart[j - 1]
-                                pre_noise_pred_text = pre_noise_pred_text * noise_pred_text.norm(p=2, dim=[1, 2, 3],
-                                                                                                 keepdim=True).mean() / pre_noise_pred_text.norm(
-                                    p=2,
-                                    dim=[1, 2, 3],
-                                    keepdim=True).mean()
-                                new_d = noise_pred_text - pre_noise_pred_text
-
-
-                            if j > guide_idx_restart:
-                                scaling = sigma / 6
-                                # scaling = 1
-                                # scaling = 0
-                                noise_pred = noise_pred_text + scaling * guidance_scale * (new_d)
-                            else:
-                                noise_pred = noise_pred_text + guidance_scale * (
-                                    noise_pred_text - noise_pred_uncond)
-
-                            #noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                         # compute the previous noisy sample x_t -> x_t-1
                         latents_next = new_scheduler.step(noise_pred, j, latents, **extra_step_kwargs).prev_sample
-                        history_noise_pred_text_restart[j] = noise_pred_text

                         # Apply 2nd order correction.
                         if (j < len(new_t_steps) - 2 or new_t_steps[-1] > 1):
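
Note (reviewer sketch, not part of the diff): the restart loop above re-noises the latents from a lower noise level back to a higher one (via the pipeline's `add_noise_between_t`) and then denoises again with plain classifier-free guidance. Conceptually, a restart jump adds fresh Gaussian noise with standard deviation sqrt(sigma_max^2 - sigma_min^2), optionally scaled by S_noise; a minimal sketch under that assumption (the helper name and arguments here are illustrative, not the pipeline's API):

    import torch

    def restart_renoise(latents, sigma_min, sigma_max, s_noise=1.0, generator=None):
        # jump from noise level sigma_min back up to sigma_max by adding fresh noise
        noise = torch.randn(latents.shape, generator=generator, device=latents.device)
        return latents + s_noise * (sigma_max**2 - sigma_min**2) ** 0.5 * noise

    # toy usage with hypothetical noise levels
    x_restarted = restart_renoise(torch.zeros(1, 4, 64, 64), sigma_min=0.5, sigma_max=5.0)
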