Skip to content

Commit 63b7f01

Browse files
committed
Improve consistency models sampling implementation.
1 parent 799ab23 commit 63b7f01

File tree

2 files changed

+126
-33
lines changed

2 files changed

+126
-33
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .pipeline_consistency_models import ConsistencyModelPipeline

src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py

+125-33
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import inspect
2-
from typing import List, Optional, Tuple, Union
2+
from typing import List, Optional, Tuple, Union, Callable
33

44
import torch
55

@@ -8,6 +8,17 @@
88
from ...utils import randn_tensor
99
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
1010

11+
12+
def append_dims(x, target_dims):
    """Right-pad the shape of `x` with singleton axes until it has `target_dims` dimensions.

    Args:
        x: Object exposing `ndim` and supporting `Ellipsis`/`None` indexing (e.g. a tensor).
        target_dims: Number of dimensions the result should have.

    Returns:
        `x` indexed with `target_dims - x.ndim` extra trailing axes of size 1.

    Raises:
        ValueError: If `x` already has more than `target_dims` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    # Keep every existing axis (Ellipsis) and add one `None` per missing trailing axis.
    trailing_axes = (None,) * missing
    return x[(Ellipsis, *trailing_axes)]
20+
21+
1122
class ConsistencyModelPipeline(DiffusionPipeline):
1223
r"""
1324
TODO
@@ -40,30 +51,76 @@ def prepare_extra_step_kwargs(self, generator, eta):
4051
extra_step_kwargs["generator"] = generator
4152
return extra_step_kwargs
4253

54+
def get_scalings(self, sigma, sigma_data: float = 0.5):
    """
    Compute the three scaling coefficients `(c_skip, c_out, c_in)` for a noise level `sigma`.

    Args:
        sigma: Noise level; any value supporting arithmetic with floats (number or tensor).
        sigma_data: Data scale used in the formulas. Defaults to 0.5.

    Returns:
        Tuple `(c_skip, c_out, c_in)` where
        `c_skip = sigma_data^2 / (sigma^2 + sigma_data^2)`,
        `c_out = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)` and
        `c_in = 1 / sqrt(sigma^2 + sigma_data^2)`.
    """
    # All three coefficients share the same squared-norm term.
    total_var = sigma**2 + sigma_data**2
    total_std = total_var**0.5
    skip_scale = sigma_data**2 / total_var
    out_scale = sigma * sigma_data / total_std
    in_scale = 1 / total_std
    return skip_scale, out_scale, in_scale
59+
60+
@staticmethod
def get_scalings_for_boundary_condition(sigma, sigma_min, sigma_data: float = 0.5):
    """
    Compute `(c_skip, c_out, c_in)` with `sigma` shifted by `sigma_min` in `c_skip` and `c_out`.

    This is a static method because the signature has no `self`. `denoise` calls it through
    the instance (`self.get_scalings_for_boundary_condition(sigma, sigma_min, ...)`); without
    the decorator the instance would be bound to `sigma` and the call would raise a
    `TypeError`.

    Args:
        sigma: Current noise level (number or tensor).
        sigma_min: Smallest noise level; at `sigma == sigma_min` the formulas give
            `c_skip == 1` and `c_out == 0`.
        sigma_data: Data scale used in the formulas. Defaults to 0.5.

    Returns:
        Tuple `(c_skip, c_out, c_in)` where
        `c_skip = sigma_data^2 / ((sigma - sigma_min)^2 + sigma_data^2)`,
        `c_out = (sigma - sigma_min) * sigma_data / sqrt(sigma^2 + sigma_data^2)` and
        `c_in = 1 / sqrt(sigma^2 + sigma_data^2)`.
    """
    shifted_sigma = sigma - sigma_min
    # Note: c_out and c_in normalise by the *unshifted* sigma.
    total_std = (sigma**2 + sigma_data**2) ** 0.5
    c_skip = sigma_data**2 / (shifted_sigma**2 + sigma_data**2)
    c_out = shifted_sigma * sigma_data / total_std
    c_in = 1 / total_std
    return c_skip, c_out, c_in
71+
72+
def denoise(self, x_t, sigma, sigma_min, sigma_data: float = 0.5, clip_denoised=True):
    """
    Evaluate `self.unet` on a rescaled copy of `x_t` and blend its output with `x_t`.

    Args:
        x_t: Noisy sample passed to the network after scaling by `c_in`.
        sigma: Noise level of `x_t`; must be a tensor because `torch.log` is applied to it.
        sigma_min: Minimum noise level forwarded to `get_scalings_for_boundary_condition`.
        sigma_data: Data scale forwarded to `get_scalings_for_boundary_condition`.
        clip_denoised: When true, clamp the blended result to `[-1, 1]`.

    Returns:
        Tuple `(model_output, denoised)`: the raw `.sample` of the network and
        `c_out * model_output + c_skip * x_t` (optionally clamped).
    """
    scalings = self.get_scalings_for_boundary_condition(sigma, sigma_min, sigma_data=sigma_data)
    # Give every coefficient the rank of x_t so it broadcasts over the trailing axes.
    c_skip, c_out, c_in = (append_dims(coeff, x_t.ndim) for coeff in scalings)

    # The network is conditioned on a log-scaled sigma; the tiny offset presumably guards log(0).
    rescaled_t = 1000 * 0.25 * torch.log(sigma + 1e-44)
    network_input = c_in * x_t
    model_output = self.unet(network_input, rescaled_t).sample

    blended = c_out * model_output + c_skip * x_t
    denoised = blended.clamp(-1, 1) if clip_denoised else blended
    return model_output, denoised
86+
87+
@staticmethod
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative.

    This is a static method because the signature has no `self`; without the decorator,
    calling it through a pipeline instance would bind the instance to `x` and fail.

    Args:
        x: Current sample.
        sigma: Noise level tensor; trailing singleton axes are appended so it broadcasts
            against `x`.
        denoised: Denoised estimate with the same shape as `x`.

    Returns:
        `(x - denoised) / sigma`, with `sigma` expanded to the rank of `x`.
    """
    return (x - denoised) / append_dims(sigma, x.ndim)
90+
4391
def add_noise_to_input(
    self,
    sample: torch.FloatTensor,
    sigma_hat: Union[float, torch.Tensor],
    sigma_min: float,
    sigma_max: float,
    s_noise: float = 1.0,
    generator: Optional[torch.Generator] = None,
):
    """
    Add Gaussian noise to `sample` so that it sits at noise level `sigma_hat`.

    Args:
        sample: Sample to perturb.
        sigma_hat: Target noise level. May be a Python number or a tensor; numbers are
            converted to a tensor on `sample.device` because `.clamp` exists only on tensors.
        sigma_min: Lower clamp bound for `sigma_hat`; also subtracted in quadrature.
        sigma_max: Upper clamp bound for `sigma_hat`.
        s_noise: Multiplier applied to the sampled standard normal noise.
        generator: Optional random generator forwarded to `randn_tensor`.

    Returns:
        `sample + sqrt(sigma_hat^2 - sigma_min^2) * z` with `z = s_noise * randn`.
    """
    if not torch.is_tensor(sigma_hat):
        sigma_hat = torch.as_tensor(sigma_hat, device=sample.device)

    # Clamping to >= sigma_min keeps the radicand below non-negative (for non-negative sigma_min).
    sigma_hat = sigma_hat.clamp(min=sigma_min, max=sigma_max)

    # sample z ~ N(0, s_noise^2 * I)
    z = s_noise * randn_tensor(sample.shape, generator=generator, device=sample.device)

    # tau = sigma_hat; eps = sigma_min
    noise_scale = (sigma_hat**2 - sigma_min**2) ** 0.5
    sample_hat = sample + noise_scale * z

    return sample_hat
56110

57111
@torch.no_grad()
58112
def __call__(
59113
self,
60114
batch_size: int = 1,
61-
num_inference_steps: int = 2000,
115+
num_inference_steps: int = 40,
116+
clip_denoised: bool = True,
117+
sigma_data: float = 0.5,
62118
eta: float = 0.0,
63119
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
64120
output_type: Optional[str] = "pil",
65121
return_dict: bool = True,
66-
**kwargs,
122+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
123+
callback_steps: int = 1,
67124
):
68125
r"""
69126
Args:
@@ -87,33 +144,72 @@ def __call__(
87144
img_size = img_size = self.unet.config.sample_size
88145
shape = (batch_size, 3, img_size, img_size)
89146
device = self.device
147+
scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
148+
scheduler_has_sigma_min = hasattr(self.scheduler, "sigma_min")
149+
assert scheduler_has_sigma_min or scheduler_is_in_sigma_space, "Scheduler needs to have sigmas"
90150

91151
# 1. Sample image latents x_0 ~ N(0, sigma_0^2 * I)
92152
sample = randn_tensor(shape, generator=generator, device=device) * self.scheduler.init_noise_sigma
93153

94-
# 2. Set timesteps
154+
# 2. Set timesteps and get sigmas
95155
self.scheduler.set_timesteps(num_inference_steps)
96-
# TODO: should schedulers always have sigmas? I think the original code always uses sigmas
97-
# self.scheduler.set_sigmas(num_inference_steps)
156+
timesteps = self.scheduler.timesteps
98157

99158
# 3. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
100159
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
101160

102161
# 4. Denoising loop
103-
# num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
104-
with self.progress_bar(total=num_inference_steps) as progress_bar:
105-
for i, t in enumerate(self.scheduler.timesteps):
106-
# TODO: handle class labels?
107-
model_output = self.unet(sample, t)
108-
109-
sample = self.scheduler.step(model_output, t, sample, **extra_step_kwargs).prev_sample
110-
111-
# TODO: need to handle karras sigma stuff here?
112-
113-
# TODO: need to support callbacks?
162+
if scheduler_has_sigma_min:
163+
# 4.1 Scheduler which can add noise to input (e.g. KarrasVeScheduler)
164+
sigma_min = self.scheduler.sigma_min
165+
sigma_max = self.scheduler.sigma_max
166+
s_noise = self.scheduler.s_noise
167+
sigmas = self.scheduler.schedule
168+
169+
# First evaluate the consistency model. This will be the output sample if num_inference_steps == 1
170+
sigma = sigmas[timesteps[0]]
171+
_, sample = self.denoise(sample, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised)
172+
173+
# If num_inference_steps > 1, perform multi-step sampling (stochastic_iterative_sampler)
174+
# Alternate adding noise and evaluating the consistency model
175+
for i, t in self.progress_bar(enumerate(self.scheduler.timesteps[1:])):
176+
sigma = sigmas[t]
177+
sigma_prev = sigmas[t - 1]
178+
if hasattr(self.scheduler, "add_noise_to_input"):
179+
sample_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)[0]
180+
else:
181+
sample_hat = self.add_noise_to_input(sample, sigma, sigma_prev, sigma_min, sigma_max, s_noise=s_noise, generator=generator)
182+
183+
_, sample = self.denoise(sample_hat, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised)
184+
else:
185+
# 4.2 Karras-style scheduler in sigma space (e.g. HeunDiscreteScheduler)
186+
sigma_min = self.scheduler.sigmas[-1]
187+
# TODO: warmup steps logic correct?
188+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
189+
with self.progress_bar(total=num_inference_steps) as progress_bar:
190+
for i, t in enumerate(timesteps):
191+
step_idx = self.scheduler.index_for_timestep(t)
192+
sigma = self.scheduler.sigmas[step_idx]
193+
# TODO: handle class labels?
194+
model_output, denoised = self.denoise(
195+
sample, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised
196+
)
197+
198+
# Karras-style schedulers already convert to a ODE derivative inside step()
199+
sample = self.scheduler.step(denoised, t, sample, **extra_step_kwargs).prev_sample
200+
201+
# TODO: need to handle karras sigma stuff here?
202+
203+
# TODO: differs from callback support in original code
204+
# See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L459
205+
# call the callback, if provided
206+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
207+
progress_bar.update()
208+
if callback is not None and i % callback_steps == 0:
209+
callback(i, t, sample)
114210

115211
# 5. Post-process image sample
116-
sample = sample.clamp(0, 1)
212+
sample = (sample / 2 + 0.5).clamp(0, 1)
117213
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
118214

119215
if output_type == "pil":
@@ -125,7 +221,3 @@ def __call__(
125221
# TODO: Offload to cpu?
126222

127223
return ImagePipelineOutput(images=sample)
128-
129-
130-
131-

0 commit comments

Comments
 (0)