
Commit c178ea6

Revert "update"
This reverts commit 1900382.
1 parent: 1900382 · commit: c178ea6

14 files changed: +9 -1121 lines

README.md (+1 -1)

@@ -16,7 +16,7 @@ Empirically, **Restart sampler surpasses previous diffusion SDE and ODE samplers
 
 Results on [Stable Diffusion v1.5](https://github.com/huggingface/diffusers):
 
-![restart-min](/Users/aaronxu/Desktop/Experiment/github_repos/restart/assets/restart-min.gif)
+![schematic](assets/vis.png)
 
 ---
 

assets/.DS_Store (0 Bytes)
Binary file not shown.

assets/restart-min.gif (-9.91 MB)
Binary file not shown.

diffuser/.DS_Store (0 Bytes)
Binary file not shown.

diffuser/diffusers/.DS_Store (0 Bytes)
Binary file not shown.

diffuser/diffusers/__init__.py (-1)

@@ -131,7 +131,6 @@
     StableDiffusionModelEditingPipeline,
     StableDiffusionPanoramaPipeline,
     StableDiffusionPipeline,
-    StableDiffusionParticlePipeline,
     StableDiffusionPipelineSafe,
     StableDiffusionPix2PixZeroPipeline,
     StableDiffusionSAGPipeline,
(unnamed file, 0 Bytes)
Binary file not shown.

diffuser/diffusers/pipelines/__init__.py (-1)

@@ -55,7 +55,6 @@
     StableDiffusionImageVariationPipeline,
     StableDiffusionImg2ImgPipeline,
     StableDiffusionInpaintPipeline,
-    StableDiffusionParticlePipeline,
     StableDiffusionInpaintPipelineLegacy,
     StableDiffusionInstructPix2PixPipeline,
     StableDiffusionLatentUpscalePipeline,

diffuser/diffusers/pipelines/stable_diffusion/__init__.py (-1)

@@ -44,7 +44,6 @@ class StableDiffusionPipelineOutput(BaseOutput):
 else:
     from .pipeline_cycle_diffusion import CycleDiffusionPipeline
     from .pipeline_stable_diffusion import StableDiffusionPipeline
-    from .pipeline_stable_diffusion_particle import StableDiffusionParticlePipeline
     from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
     from .pipeline_stable_diffusion_controlnet import StableDiffusionControlNetPipeline
     from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
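Taken together, the three `__init__.py` hunks undo the export chain for the experimental `StableDiffusionParticlePipeline`: this diffusers fork re-exports each pipeline through three package layers, so removing one class touches all three files. A comment-only sketch of that layering, with paths from the hunks above (the middle re-export line is an assumption based on diffusers' usual layout, not shown in this diff):

```python
# Export chain for a pipeline class in this diffusers fork.
#
# diffuser/diffusers/pipelines/stable_diffusion/__init__.py:
#     from .pipeline_stable_diffusion_particle import StableDiffusionParticlePipeline
#
# diffuser/diffusers/pipelines/__init__.py (assumed form):
#     from .stable_diffusion import StableDiffusionParticlePipeline
#
# diffuser/diffusers/__init__.py:
#     from .pipelines import StableDiffusionParticlePipeline
#
# Removing the class therefore means deleting one line in each file,
# which is exactly what this revert does.
```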

diffuser/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py (+5 -105)

@@ -35,7 +35,6 @@
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
 import copy
-import numpy as np
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 

@@ -673,13 +672,9 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
 
-
         restart_list = {}
-        all_sigmas = torch.sqrt((1 - self.scheduler.alphas_cumprod) / self.scheduler.alphas_cumprod)
-        sigma_max = all_sigmas.max()
-
         if restart:
-            #all_sigmas = torch.sqrt((1 - self.scheduler.alphas_cumprod) / self.scheduler.alphas_cumprod)
+            all_sigmas = torch.sqrt((1 - self.scheduler.alphas_cumprod) / self.scheduler.alphas_cumprod)
             sigmas = all_sigmas[timesteps.cpu().numpy()]
 
             # {t_min: [N_restart, K, t_max], ... }
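For context, the re-inlined `all_sigmas` line maps the scheduler's variance-preserving cumulative alphas to VE-style noise levels, which is the scale the restart intervals are specified in. A minimal sketch, assuming a DDPM-style linear beta schedule in place of `self.scheduler.alphas_cumprod`:

```python
import torch

# Assumed stand-in for self.scheduler.alphas_cumprod: a DDPM-style
# linear beta schedule with 1000 training timesteps.
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# VP forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps.
# Dividing by sqrt(abar_t) gives x_t' = x_0 + sigma_t * eps with
# sigma_t = sqrt((1 - abar_t) / abar_t), the VE-style level used above.
all_sigmas = torch.sqrt((1 - alphas_cumprod) / alphas_cumprod)

timesteps = torch.tensor([999, 749, 499, 249, 0])  # illustrative subset
sigmas = all_sigmas[timesteps]  # noise level at each sampled timestep
print(sigmas)  # decreases monotonically toward t = 0
```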
@@ -704,11 +699,6 @@ def __call__(
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         image_list = []
-        pre_noise_pred_text = None
-        pre_pred_original_sample = None
-        pre_2_noise_pred_text = None
-        history_noise_pred_text = torch.zeros((len(timesteps[:-1]), *latents.shape)).to(latents.device)
-        pre_latent = None
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps[:-1]):
                 # expand the latents if we are doing classifier free guidance
@@ -726,77 +716,10 @@ def __call__(
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    #noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                    guide_idx = 5
-
-                    if i > guide_idx:
-                        # In the first 10 steps, we still follow the conventional cfg (I will explain the reason)
-                        pre_noise_pred_text = history_noise_pred_text[i-1]
-                        # new_d = pre_noise_pred_text \
-                        #     * noise_pred_text.norm(p=2, dim=[1,2,3], keepdim=True) \
-                        #     / pre_noise_pred_text.norm(p=2, dim=[1,2,3], keepdim=True)
-                        # if i < 10:
-                        #     new_d = pre_noise_pred_text - noise_pred_text
-                        # else:
-                        #     new_d = noise_pred_text - pre_noise_pred_text
-
-                        pre_noise_pred_text = pre_noise_pred_text * noise_pred_text.norm(p=2, dim=[1, 2, 3],
-                                                                                         keepdim=True).mean() / pre_noise_pred_text.norm(p=2,
-                                                                                                                                         dim=[1,2,3],
-                                                                                                                                         keepdim=True).mean()
-                        new_d = noise_pred_text - pre_noise_pred_text
-
-                    # # new_off_set = noise_pred_text - new_d
-                    # old_off_set = noise_pred_text - noise_pred_uncond
-                    # #
-                    # noise_pred_text_vec = noise_pred_text.view(len(noise_pred_text), -1)
-                    # pre_noise_pred_text_vec = pre_noise_pred_text.view(len(pre_noise_pred_text), -1)
-                    # pre_noise_pred_text_vec_2 = history_noise_pred_text[i-2].view(len(pre_noise_pred_text), -1)
-                    # noise_pred_uncond_vec = noise_pred_uncond.view(len(noise_pred_uncond), -1)
-                    #
-                    # for num in range(len(noise_pred_text)):
-                    #     A = torch.stack([noise_pred_text_vec[num],
-                    #                      pre_noise_pred_text_vec[num]]).transpose(0,1)
-                    #     B = (noise_pred_text_vec[num] - noise_pred_uncond_vec[num]).view(-1, 1)
-                    #
-                    #     X = torch.linalg.lstsq(A, B).solution
-                    #     # print(f"num:{num}, X:{X}")
-                    #     new_d[num] = noise_pred_text[num] * X[0] + \
-                    #                  pre_noise_pred_text[num] * X[1]
-
-                    # compute cosine similarity
-                    # orthogonal = old_off_set - new_d
-                    # orthogonal_nor = orthogonal / orthogonal.norm(p=2, dim=[1,2,3], keepdim=True)
-                    # new_off_set_nor = new_d / new_d.norm(p=2, dim=[1,2,3], keepdim=True)
-                    # old_off_set_nor = old_off_set / old_off_set.norm(p=2, dim=[1, 2, 3], keepdim=True)
-                    #
-                    # print(i, sigma, (new_off_set_nor * old_off_set_nor).sum(dim=[1,2,3]),
-                    #       (new_off_set_nor * orthogonal_nor).sum(dim=[1,2,3]))
-
-                    # pre_pred_original_sample = pred_original_sample
-                    # pre_latent = latents
-
-                    #noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)
-                    sigma = all_sigmas[t.cpu().numpy()]
-                    sigma = sigma.to(noise_pred_text.device)
-                    print(i, sigma)
-
-                    if i > guide_idx:
-                        scaling = sigma / 6
-                        # scaling = 1
-                        # scaling = 0
-                        noise_pred = noise_pred_text + scaling * guidance_scale * (new_d)
-                    else:
-                        noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                    history_noise_pred_text[i] = noise_pred_text
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                temp = 1 if i > guide_idx else 1
-                sde = False if i > guide_idx else False
-                latents_next = self.scheduler.step(noise_pred, i, latents, temp=temp, sde=sde, **extra_step_kwargs).prev_sample
+                latents_next = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
 
                 # Apply 2nd order correction (Heun).
                 if second_order and i < len(timesteps) - 2:
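For reference, the two guidance rules this hunk swaps. The restored line is standard classifier-free guidance; the deleted experiment instead pushed along `new_d`, the step-to-step change in the text-conditional prediction, norm-matched and damped by `sigma / 6`. A self-contained sketch of both rules (tensor shapes, the guidance scale, and the sigma value are illustrative assumptions, not values from the repo):

```python
import torch

def cfg(uncond, text, guidance_scale):
    """Standard classifier-free guidance, as restored by this revert."""
    return uncond + guidance_scale * (text - uncond)

def history_guidance(text, prev_text, guidance_scale, sigma):
    """Sketch of the reverted experiment: guide along the step-to-step
    change in the text prediction rather than the text/uncond offset."""
    # Norm-match the previous prediction to the current one (batch-mean norms).
    prev = prev_text * (
        text.norm(p=2, dim=[1, 2, 3], keepdim=True).mean()
        / prev_text.norm(p=2, dim=[1, 2, 3], keepdim=True).mean()
    )
    new_d = text - prev
    # Damped by sigma / 6, so the correction fades as the noise level drops.
    return text + (sigma / 6) * guidance_scale * new_d

# Illustrative shapes: batch 1, SD latent 4x64x64 (assumed).
uncond = torch.randn(1, 4, 64, 64)
text = torch.randn(1, 4, 64, 64)
prev_text = torch.randn(1, 4, 64, 64)

restored = cfg(uncond, text, guidance_scale=7.5)
reverted = history_guidance(text, prev_text, guidance_scale=7.5, sigma=2.0)
```

Note that the experimental rule starts from `noise_pred_text` rather than `noise_pred_uncond`, so even with `new_d = 0` it is not equivalent to standard CFG.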
@@ -844,7 +767,7 @@ def __call__(
                 print("restart steps:", new_t_steps)
 
                 latents = self.scheduler.add_noise_between_t(latents, new_t_steps[-1], new_t_steps[0], generator, S_noise=S_noise)
-                history_noise_pred_text_restart = torch.zeros((len(new_t_steps[:-1]), *latents.shape)).to(latents.device)
+
                 for j, new_t in enumerate(new_t_steps[:-1]):
                     # print(" restart:", new_t)
                     latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -859,35 +782,12 @@ def __call__(
                     ).sample
 
                     # perform guidance
-                    guide_idx_restart = 0
                     if do_classifier_free_guidance:
-
                         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        if j > guide_idx_restart:
-                            # In the first 10 steps, we still follow the conventional cfg (I will explain the reason)
-                            pre_noise_pred_text = history_noise_pred_text_restart[j - 1]
-                            pre_noise_pred_text = pre_noise_pred_text * noise_pred_text.norm(p=2, dim=[1, 2, 3],
-                                                                                             keepdim=True).mean() / pre_noise_pred_text.norm(
-                                p=2,
-                                dim=[1, 2, 3],
-                                keepdim=True).mean()
-                            new_d = noise_pred_text - pre_noise_pred_text
-
-
-                        if j > guide_idx_restart:
-                            scaling = sigma / 6
-                            # scaling = 1
-                            # scaling = 0
-                            noise_pred = noise_pred_text + scaling * guidance_scale * (new_d)
-                        else:
-                            noise_pred = noise_pred_text + guidance_scale * (
-                                noise_pred_text - noise_pred_uncond)
-
-                        #noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                     # compute the previous noisy sample x_t -> x_t-1
                     latents_next = new_scheduler.step(noise_pred, j, latents, **extra_step_kwargs).prev_sample
-                    history_noise_pred_text_restart[j] = noise_pred_text
 
                     # Apply 2nd order correction.
                     if (j < len(new_t_steps) - 2 or new_t_steps[-1] > 1):
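The Restart mechanics framing this hunk survive the revert unchanged: `add_noise_between_t` jumps the latent from the interval's low noise level back up to its high one, and the `for j, new_t in enumerate(new_t_steps[:-1])` loop re-denoises it. A schematic sketch of that re-noising step under VE-style noise levels; the function below is a hypothetical stand-in, not the scheduler's actual implementation:

```python
import torch

def add_noise_between_t_sketch(latents, sigma_min, sigma_max,
                               generator=None, S_noise=1.0):
    """Hypothetical stand-in for scheduler.add_noise_between_t: move from
    noise level sigma_min back up to sigma_max.

    Under a VE-style process x = x_0 + sigma * eps, raising the noise level
    requires fresh Gaussian noise with variance sigma_max^2 - sigma_min^2
    (the Restart sampler's forward jump); the repo's scheduler works in
    timesteps rather than sigmas, so details differ.
    """
    noise = torch.randn(latents.shape, generator=generator, dtype=latents.dtype)
    return latents + S_noise * (sigma_max ** 2 - sigma_min ** 2) ** 0.5 * noise

# Illustrative restart interval: re-noise from sigma = 0.5 up to sigma = 2.0.
latents = torch.randn(1, 4, 64, 64)
restarted = add_noise_between_t_sketch(latents, sigma_min=0.5, sigma_max=2.0)
```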
