Commit 194b0a4
Add use_karras_sigmas to DPMSolverSinglestepScheduler (#3476)
* add use_karras_sigmas
* add karras test
* add doc
1 parent 6dd3871
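For context, a minimal usage sketch (not part of this commit; the model id and prompt are placeholders): the new flag can be passed as a config override when swapping the scheduler into an existing pipeline.

# Hypothetical usage sketch: enable Karras sigmas via a config override.
import torch
from diffusers import DiffusionPipeline, DPMSolverSinglestepScheduler

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder model id
    torch_dtype=torch.float16,
)
pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=True
)
pipe = pipe.to("cuda")
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]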

2 files changed: +64 -0 lines changed
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 52 additions & 0 deletions
@@ -117,6 +117,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final (`bool`, default `True`):
             whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable
             this to use up all the function evaluations.
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
+            noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
+            of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
         lambda_min_clipped (`float`, default `-inf`):
             the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for
             cosine (squaredcos_cap_v2) noise schedule.
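For reference (my gloss, not commit text): Equation (5) of Karras et al. (2022) spaces the N noise levels as sigma_i = (sigma_max^(1/rho) + (i/(N-1)) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho for i = 0, ..., N-1, decreasing from sigma_max to sigma_min, with rho = 7 in the paper. This is exactly what the `_convert_to_karras` helper added further down computes.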
@@ -150,6 +154,7 @@ def __init__(
         algorithm_type: str = "dpmsolver++",
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
+        use_karras_sigmas: Optional[bool] = False,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
     ):
@@ -197,6 +202,7 @@ def __init__(
         self.model_outputs = [None] * solver_order
         self.sample = None
         self.order_list = self.get_order_list(num_train_timesteps)
+        self.use_karras_sigmas = use_karras_sigmas
 
     def get_order_list(self, num_inference_steps: int) -> List[int]:
         """
@@ -252,6 +258,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             .copy()
             .astype(np.int64)
         )
+
+        if self.use_karras_sigmas:
+            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+            log_sigmas = np.log(sigmas)
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps = np.flip(timesteps).copy().astype(np.int64)
+
         self.timesteps = torch.from_numpy(timesteps).to(device)
         self.model_outputs = [None] * self.config.solver_order
         self.sample = None
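A note on the branch above (my gloss): `sigmas` is the per-train-step noise level sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t). `_convert_to_karras` picks Karras-spaced noise levels from that range, and `_sigma_to_t` maps each one back to a fractional train timestep, which is then rounded and flipped so `timesteps` runs from high noise to low. A toy round trip under a stand-in schedule:

# Toy sketch, assuming a linear alpha_bar stand-in (not the real beta schedule).
import numpy as np

alphas_cumprod = np.linspace(0.99, 0.01, 1000)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5  # sigma_t, increasing in t
log_sigmas = np.log(sigmas)

target = sigmas[500]  # pretend this is one of the Karras sigmas
# crude nearest-neighbor stand-in for the interpolating _sigma_to_t below
t = int(np.abs(log_sigmas - np.log(target)).argmin())
assert t == 500  # the matching sigma maps back to its own train timestep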
@@ -299,6 +313,44 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
 
         return sample
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
     ) -> torch.FloatTensor:
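A standalone sanity check for the Karras ramp (illustrative sigma range, not commit code): the schedule hits sigma_max and sigma_min exactly at the endpoints, and because rho > 1 the steps cluster near sigma_min, giving the low-noise end of sampling finer resolution.

# Illustrative check of the rho=7 Karras spacing; the sigma range is assumed.
import numpy as np

sigma_min, sigma_max, rho = 0.03, 14.6, 7.0
ramp = np.linspace(0, 1, 10)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
assert np.isclose(sigmas[0], sigma_max) and np.isclose(sigmas[-1], sigma_min)
print(np.round(sigmas, 3))  # large gaps at high noise, fine gaps near the end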

tests/schedulers/test_scheduler_dpm_single.py

Lines changed: 12 additions & 0 deletions
@@ -215,12 +215,24 @@ def test_full_loop_no_noise(self):
 
         assert abs(result_mean.item() - 0.2791) < 1e-3
 
+    def test_full_loop_with_karras(self):
+        sample = self.full_loop(use_karras_sigmas=True)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2248) < 1e-3
+
     def test_full_loop_with_v_prediction(self):
         sample = self.full_loop(prediction_type="v_prediction")
         result_mean = torch.mean(torch.abs(sample))
 
         assert abs(result_mean.item() - 0.1453) < 1e-3
 
+    def test_full_loop_with_karras_and_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.0649) < 1e-3
+
     def test_fp16_support(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
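The new tests reuse the suite's existing `full_loop` helper and pin expected result means for the Karras path. From a diffusers dev checkout they should be selectable by keyword, e.g.:

python -m pytest tests/schedulers/test_scheduler_dpm_single.py -k karras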
