Commit f2e53da

Add CMStochasticIterativeScheduler, which implements the multi-step sampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling.
1 parent 63b7f01 commit f2e53da

2 files changed: +279 -45

src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py

+73 -45
@@ -51,13 +51,52 @@ def prepare_extra_step_kwargs(self, generator, eta):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
+    def get_sigma_min_max_from_scheduler(self):
+        # Get sigma_min, sigma_max in original sigma space, not Karras sigma space
+        # (e.g. not exponentiated by 1 / rho)
+        if hasattr(self.scheduler, "sigma_min"):
+            sigma_min = self.scheduler.sigma_min
+            sigma_max = self.scheduler.sigma_max
+        elif hasattr(self.scheduler, "sigmas"):
+            # Karras-style scheduler (e.g. EulerDiscreteScheduler, HeunDiscreteScheduler)
+            # Get sigma_min, sigma_max before they're converted into Karras sigma space by set_timesteps
+            # TODO: Karras schedulers are inconsistent about how they initialize sigmas in __init__.
+            # For example, EulerDiscreteScheduler gets sigmas in original sigma space, but HeunDiscreteScheduler
+            # initializes them through set_timesteps, which potentially leaves the sigmas in Karras sigma space.
+            # TODO: For example, in EulerDiscreteScheduler, a value of 0.0 is appended to the sigmas when they are
+            # initialized in __init__. But wouldn't we usually want sigma_min to be a small positive number,
+            # following the consistency models paper?
+            # See e.g. https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L13
+            sigma_min = self.scheduler.sigmas[-1].item()
+            sigma_max = self.scheduler.sigmas[0].item()
+        else:
+            raise ValueError(
+                f"Scheduler {self.scheduler.__class__} does not have sigma_min or sigma_max."
+            )
+        return sigma_min, sigma_max
+
+    def get_sigmas_from_scheduler(self):
+        if hasattr(self.scheduler, "sigmas"):
+            # e.g. HeunDiscreteScheduler
+            sigmas = self.scheduler.sigmas
+        elif hasattr(self.scheduler, "schedule"):
+            # e.g. KarrasVeScheduler
+            sigmas = self.scheduler.schedule
+        else:
+            raise ValueError(
+                f"Scheduler {self.scheduler.__class__} does not have sigmas."
+            )
+        return sigmas
+
     def get_scalings(self, sigma, sigma_data: float = 0.5):
         c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
         c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
         c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5
         return c_skip, c_out, c_in
 
     def get_scalings_for_boundary_condition(self, sigma, sigma_min, sigma_data: float = 0.5):
+        # sigma_min should be in original sigma space, not in Karras sigma space
+        # (e.g. not exponentiated by 1 / rho)
         c_skip = sigma_data**2 / (
             (sigma - sigma_min) ** 2 + sigma_data**2
         )
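
As a quick sanity check on the boundary condition above (not part of the commit): at sigma == sigma_min the skip scaling goes to 1 and the output scaling to 0, so the parameterized model reduces to the identity at the minimum noise level, which is exactly the consistency-model boundary condition. A minimal standalone sketch; since this hunk only shows c_skip, the c_out expression below is an assumption based on the consistency models paper:

# Standalone sketch of the boundary-condition scalings. Only c_skip appears in the
# hunk above; c_out is assumed from the consistency models paper.
def scalings_for_boundary_condition(sigma, sigma_min, sigma_data=0.5):
    c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
    c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    return c_skip, c_out

print(scalings_for_boundary_condition(0.002, 0.002))  # (1.0, 0.0): identity at sigma == sigma_min
print(scalings_for_boundary_condition(80.0, 0.002))   # c_skip ~ 0, c_out ~ 0.5: model output dominates at high noise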
@@ -73,6 +112,8 @@ def denoise(self, x_t, sigma, sigma_min, sigma_data: float = 0.5, clip_denoised=
         """
         Run the consistency model forward pass.
         """
+        # sigma_min should be in original sigma space, not in Karras sigma space
+        # (e.g. not exponentiated by 1 / rho)
         c_skip, c_out, c_in = [
             append_dims(x, x_t.ndim)
             for x in self.get_scalings_for_boundary_condition(sigma, sigma_min, sigma_data=sigma_data)
@@ -88,26 +129,6 @@ def to_d(x, sigma, denoised):
         """Converts a denoiser output to a Karras ODE derivative."""
         return (x - denoised) / append_dims(sigma, x.ndim)
 
-    def add_noise_to_input(
-        self,
-        sample: torch.FloatTensor,
-        sigma_hat: float,
-        sigma_min: float,
-        sigma_max: float,
-        s_noise: float = 1.0,
-        generator: Optional[torch.Generator] = None,
-    ):
-        # Clamp sigma_hat
-        sigma_hat = sigma_hat.clamp(min=sigma_min, max=sigma_max)
-
-        # sample z ~ N(0, s_noise^2 * I)
-        z = s_noise * randn_tensor(sample.shape, generator=generator, device=sample.device)
-
-        # tau = sigma_hat; eps = sigma_min
-        sample_hat = sample + ((sigma_hat**2 - sigma_min**2) ** 0.5 * z)
-
-        return sample_hat
-
     @torch.no_grad()
     def __call__(
         self,
@@ -144,69 +165,76 @@ def __call__(
         img_size = self.unet.config.sample_size
         shape = (batch_size, 3, img_size, img_size)
         device = self.device
-        scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
-        scheduler_has_sigma_min = hasattr(self.scheduler, "sigma_min")
-        assert scheduler_has_sigma_min or scheduler_is_in_sigma_space, "Scheduler needs to have sigmas"
 
         # 1. Sample image latents x_0 ~ N(0, sigma_0^2 * I)
         sample = randn_tensor(shape, generator=generator, device=device) * self.scheduler.init_noise_sigma
 
         # 2. Set timesteps and get sigmas
+        # Get sigma_min, sigma_max in original sigma space (not Karras sigma space)
+        sigma_min, sigma_max = self.get_sigma_min_max_from_scheduler()
         self.scheduler.set_timesteps(num_inference_steps)
         timesteps = self.scheduler.timesteps
+
+        # Now get the Karras sigma schedule (which I think the original implementation always uses)
+        # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L376
+        sigmas = self.get_sigmas_from_scheduler()
 
         # 3. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 4. Denoising loop
-        if scheduler_has_sigma_min:
-            # 4.1 Scheduler which can add noise to input (e.g. KarrasVeScheduler)
-            sigma_min = self.scheduler.sigma_min
-            sigma_max = self.scheduler.sigma_max
-            s_noise = self.scheduler.s_noise
-            sigmas = self.scheduler.schedule
-
+        # TODO: hack, is there a better way to identify schedulers that implement the stochastic iterative sampling
+        # similar to stochastic_iterative_sampler in the original code?
+        if hasattr(self.scheduler, "add_noise_to_input"):
+            # 4.1 Consistency Model Stochastic Iterative Scheduler (multi-step sampling)
             # First evaluate the consistency model. This will be the output sample if num_inference_steps == 1
-            sigma = sigmas[timesteps[0]]
+            # TODO: not all schedulers have an index_for_timestep method (e.g. KarrasVeScheduler)
+            step_idx = self.scheduler.index_for_timestep(timesteps[0])
+            sigma = sigmas[step_idx]
             _, sample = self.denoise(sample, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised)
 
             # If num_inference_steps > 1, perform multi-step sampling (stochastic_iterative_sampler)
-            # Alternate adding noise and evaluating the consistency model
+            # Alternate adding noise and evaluating the consistency model on the noised input
             for i, t in self.progress_bar(enumerate(self.scheduler.timesteps[1:])):
-                sigma = sigmas[t]
-                sigma_prev = sigmas[t - 1]
-                if hasattr(self.scheduler, "add_noise_to_input"):
-                    sample_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)[0]
-                else:
-                    sample_hat = self.add_noise_to_input(sample, sigma, sigma_prev, sigma_min, sigma_max, s_noise=s_noise, generator=generator)
-
-                _, sample = self.denoise(sample_hat, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised)
-        else:
+                step_idx = self.scheduler.index_for_timestep(t)
+                sigma = sigmas[step_idx]
+                sigma_prev = sigmas[step_idx - 1]
+                sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)
+
+                model_output, denoised = self.denoise(
+                    sample_hat, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised
+                )
+
+                sample = self.scheduler.step(denoised, sigma_hat, sigma_prev, sample_hat).prev_sample
+        elif hasattr(self.scheduler, "sigmas"):
             # 4.2 Karras-style scheduler in sigma space (e.g. HeunDiscreteScheduler)
-            sigma_min = self.scheduler.sigmas[-1]
             # TODO: warmup steps logic correct?
             num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
             with self.progress_bar(total=num_inference_steps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     step_idx = self.scheduler.index_for_timestep(t)
                     sigma = self.scheduler.sigmas[step_idx]
                     # TODO: handle class labels?
+                    # TODO: check shapes, might need equivalent of s_in in original code
+                    # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L510
                     model_output, denoised = self.denoise(
                         sample, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised
                     )
 
                     # Karras-style schedulers already convert to an ODE derivative inside step()
                     sample = self.scheduler.step(denoised, t, sample, **extra_step_kwargs).prev_sample
 
-                    # TODO: need to handle karras sigma stuff here?
-
-                    # TODO: differs from callback support in original code
+                    # Note: differs from callback support in original code
                     # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L459
                     # call the callback, if provided
                     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                         progress_bar.update()
                         if callback is not None and i % callback_steps == 0:
                             callback(i, t, sample)
+        else:
+            raise ValueError(
+                f"Scheduler {self.scheduler.__class__} is not compatible with consistency models."
+            )
 
         # 5. Post-process image sample
         sample = (sample / 2 + 0.5).clamp(0, 1)
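
For reference, the multi-step branch above mirrors stochastic_iterative_sampler in the original repository: evaluate the consistency model once at the largest sigma, then alternate between noising the current estimate up to the next sigma level and re-evaluating the model on the noised input. A minimal standalone sketch of that loop, independent of the pipeline and scheduler classes; denoise below is a stand-in for the consistency model, and the default sigma values are only illustrative:

import torch

def multistep_consistency_sampling_sketch(denoise, x, sigmas, sigma_min=0.002, sigma_max=80.0, s_noise=1.0, generator=None):
    # First evaluation at the largest sigma; with a single inference step this is already the output.
    x = denoise(x, sigmas[0])
    for sigma in sigmas[1:]:
        # Noise the current estimate back up to the (clamped) target noise level...
        sigma_hat = sigma.clamp(min=sigma_min, max=sigma_max)
        z = s_noise * torch.randn(x.shape, generator=generator, device=x.device, dtype=x.dtype)
        x_hat = x + (sigma_hat**2 - sigma_min**2) ** 0.5 * z
        # ...then evaluate the consistency model on the noised input.
        x = denoise(x_hat, sigma_hat)
    return x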
New file (+206 -0)
@@ -0,0 +1,206 @@
+# Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, randn_tensor
+from .scheduling_utils import SchedulerMixin
+
+
+def append_zero(x):
+    return torch.cat([x, x.new_zeros([1])])
+
+
+@dataclass
+class CMStochasticIterativeSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's step function output.
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Derivative of predicted original image sample (x_0).
+        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.FloatTensor
+    # derivative: torch.FloatTensor
+    # pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+    the VE column of Table 1 from [1] for reference.
+    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+    https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+    differential equations." https://arxiv.org/abs/2011.13456
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.
+    For more details on the parameters, see the original paper's Appendix E: "Elucidating the Design Space of
+    Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+    optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
+    Args:
+        sigma_min (`float`): minimum noise magnitude
+        sigma_max (`float`): maximum noise magnitude
+        s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
+            A reasonable range is [1.000, 1.011].
+        s_churn (`float`): the parameter controlling the overall amount of stochasticity.
+            A reasonable range is [0, 100].
+        s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
+            A reasonable range is [0, 10].
+        s_max (`float`): the end value of the sigma range where we add noise.
+            A reasonable range is [0.2, 80].
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        sigma_data: float = 0.5,
+        sigma_min: float = 0.002,
+        sigma_max: float = 80.0,
+        rho: float = 7.0,
+        s_noise: float = 1.0,
+        s_churn: float = 0.0,
+        s_min: float = 0.0,
+        s_max: float = float("inf"),
+    ):
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = sigma_max
+
+        # setable values
+        self.num_inference_steps: int = None
+        self.timesteps: np.IntTensor = None
+        self.schedule: torch.FloatTensor = None  # sigma(t_i)
+
+        self.sigma_data = sigma_data
+        self.rho = rho
+
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        indices = (schedule_timesteps == timestep).nonzero()
+        return indices.item()
+
+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
+    def get_sigmas_karras(self):
+        """Constructs the noise schedule of Karras et al. (2022)."""
+        ramp = np.linspace(0, 1, self.num_inference_steps)
+        min_inv_rho = self.sigma_min ** (1 / self.rho)
+        max_inv_rho = self.sigma_max ** (1 / self.rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
+        return append_zero(torch.from_numpy(sigmas))
+
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+        """
+        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+        """
+        self.num_inference_steps = num_inference_steps
+        # TODO: how should timesteps be set? The original code seems to either work solely in sigma space or use
+        # hardcoded timesteps (see e.g. https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L74)
+        # TODO: should num_train_timesteps be added here?
+        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+        sigmas = self.get_sigmas_karras()
+
+        self.timesteps = torch.from_numpy(timesteps).to(device)
+        self.sigmas = sigmas.to(dtype=torch.float32, device=device)
+
+    def add_noise(self, original_samples, noise, timesteps):
+        """Add noise for training."""
+        raise NotImplementedError()
+
+    def add_noise_to_input(
+        self,
+        sample: torch.FloatTensor,
+        sigma: float,
+        generator: Optional[torch.Generator] = None,
+    ) -> Tuple[torch.FloatTensor, float]:
+        """
+        Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
+        higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+        TODO Args:
+        """
+        sigma_min = self.config.sigma_min
+        sigma_max = self.config.sigma_max
+
+        step_idx = (self.sigmas == sigma).nonzero().item()
+        sigma_hat = self.sigmas[step_idx + 1].clamp(min=sigma_min, max=sigma_max)
+
+        # sample z ~ N(0, s_noise^2 * I)
+        z = self.config.s_noise * randn_tensor(sample.shape, generator=generator, device=sample.device)
+
+        # tau = sigma_hat, eps = sigma_min
+        sample_hat = sample + ((sigma_hat**2 - sigma_min**2) ** 0.5 * z)
+
+        return sample_hat, sigma_hat
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        sigma_hat: float,
+        sigma_prev: float,
+        sample_hat: torch.FloatTensor,
+        return_dict: bool = True,
+    ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+        Args:
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            sigma_hat (`float`): TODO
+            sigma_prev (`float`): TODO
+            sample_hat (`torch.FloatTensor`): TODO
+            return_dict (`bool`): option for returning tuple rather than CMStochasticIterativeSchedulerOutput class
+            CMStochasticIterativeSchedulerOutput: updated sample in the diffusion chain and derivative (TODO double check).
+        Returns:
+            [`CMStochasticIterativeSchedulerOutput`] or `tuple`:
+            [`CMStochasticIterativeSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+        """
+        # Assume model output is the consistency model evaluated at sample_hat.
+        sample_prev = model_output
+
+        if not return_dict:
+            return (sample_prev,)
+
+        return CMStochasticIterativeSchedulerOutput(
+            prev_sample=sample_prev,
+        )
+
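
The schedule produced by get_sigmas_karras above is the standard rho-space interpolation between sigma_max and sigma_min from Karras et al. (2022), with a zero appended at the end by append_zero. A small standalone check with the scheduler's default config values (not part of the commit):

import numpy as np

def karras_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # sigma_i = (sigma_max^(1/rho) + i/(n-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    ramp = np.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return np.append(sigmas, 0.0)  # trailing zero, as in append_zero

print(karras_sigmas(5))  # decreases from 80.0 down to 0.002, then the appended 0.0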
