
Commit 5566a2b

add more explanations for the stabilizing trick (for steps < 15)

1 parent a6efda1

3 files changed: +22 −16 lines

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py (+12 −7)
@@ -103,8 +103,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         solver_type (`str`, default `dpm_solver`):
             the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly
             affects the sample quality, especially for small number of steps.
-        denoise_final (`bool`, default `True`):
-            whether to use lower-order solvers in the final steps.
+        lower_order_final (`bool`, default `True`):
+            whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
+            find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
 
     """
 
@@ -131,7 +132,7 @@ def __init__(
         dynamic_thresholding_ratio: float = 0.995,
         sample_max_value: float = 1.0,
         solver_type: str = "dpm_solver",
-        denoise_final: bool = True,
+        lower_order_final: bool = True,
     ):
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
@@ -405,17 +406,21 @@ def step(
         else:
             step_index = step_index.item()
         prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
-        denoise_final = (step_index == len(self.timesteps) - 1) and self.config.denoise_final
-        denoise_second = (step_index == len(self.timesteps) - 2) and self.config.denoise_final
+        lower_order_final = (
+            (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
+        )
+        lower_order_second = (
+            (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
+        )
 
         model_output = self.convert_model_output(model_output, timestep, sample)
         for i in range(self.config.solver_order - 1):
             self.model_outputs[i] = self.model_outputs[i + 1]
         self.model_outputs[-1] = model_output
 
-        if self.config.solver_order == 1 or self.lower_order_nums < 1 or denoise_final:
+        if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
             prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
-        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or denoise_second:
+        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
             timestep_list = [self.timesteps[step_index - 1], timestep]
             prev_sample = self.multistep_dpm_solver_second_order_update(
                 self.model_outputs, timestep_list, prev_timestep, sample
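Downstream, the renamed flag is an ordinary scheduler config argument. A minimal usage sketch, not part of this commit (the step count of 10 is an arbitrary value inside the < 15 regime where the trick applies):

from diffusers import DPMSolverMultistepScheduler

# lower_order_final defaults to True; with fewer than 15 inference steps the
# scheduler falls back to a first-order update on the last step (and to a
# second-order update on the second-to-last step when solver_order == 3).
scheduler = DPMSolverMultistepScheduler(solver_order=2, lower_order_final=True)
scheduler.set_timesteps(10)
print(scheduler.timesteps)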

src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py (+6 −5)
@@ -130,8 +130,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
         solver_type (`str`, default `dpm_solver`):
             the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly
             affects the sample quality, especially for small number of steps.
-        denoise_final (`bool`, default `True`):
-            whether to use lower-order solvers in the final steps.
+        lower_order_final (`bool`, default `True`):
+            whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
+            find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
 
     """
 
@@ -153,7 +154,7 @@ def __init__(
         dynamic_thresholding_ratio: float = 0.995,
         sample_max_value: float = 1.0,
         solver_type: str = "dpm_solver",
-        denoise_final: bool = True,
+        lower_order_final: bool = True,
     ):
         if trained_betas is not None:
             self.betas = jnp.asarray(trained_betas)
@@ -471,7 +472,7 @@ def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
 
         if self.config.solver_order == 2:
             return step_2(state)
-        elif self.config.denoise_final:
+        elif self.config.lower_order_final:
             return jax.lax.cond(
                 state.lower_order_nums < 2,
                 step_2,
@@ -493,7 +494,7 @@ def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
 
         if self.config.solver_order == 1:
             prev_sample = step_1(state)
-        elif self.config.denoise_final:
+        elif self.config.lower_order_final:
             prev_sample = jax.lax.cond(
                 state.lower_order_nums < 1,
                 step_1,
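Because the Flax `step` is meant to be jit-compiled, the fallback cannot be a Python `if` on traced state, which is why the branch is expressed with `jax.lax.cond`. A toy sketch of that selection pattern, with placeholder update functions rather than the scheduler's real DPM-Solver updates:

import jax
import jax.numpy as jnp

def first_order(x):
    return x * 0.9  # placeholder for the first-order update

def second_order(x):
    return x * 0.8  # placeholder for the second-order update

@jax.jit
def select_update(lower_order_nums, x):
    # Mirrors the pattern above: use the lower-order update until enough
    # previous model outputs have accumulated.
    return jax.lax.cond(lower_order_nums < 1, first_order, second_order, x)

print(select_update(0, jnp.ones(3)))  # takes the first-order branch
print(select_update(2, jnp.ones(3)))  # takes the second-order branch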

tests/test_scheduler.py (+4 −4)
@@ -565,7 +565,7 @@ def get_scheduler_config(self, **kwargs):
             "thresholding": False,
             "sample_max_value": 1.0,
             "solver_type": "dpm_solver",
-            "denoise_final": False,
+            "lower_order_final": False,
         }
 
         config.update(**kwargs)
@@ -702,9 +702,9 @@ def test_solver_order_and_type(self):
                 sample = self.full_loop(solver_order=order, solver_type=solver_type, predict_x0=predict_x0)
                 assert not torch.isnan(sample).any(), "Samples have nan numbers"
 
-    def test_denoise_final(self):
-        self.check_over_configs(denoise_final=True)
-        self.check_over_configs(denoise_final=False)
+    def test_lower_order_final(self):
+        self.check_over_configs(lower_order_final=True)
+        self.check_over_configs(lower_order_final=False)
 
     def test_inference_steps(self):
         for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
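As a quick smoke test of the renamed option outside the test harness, the low-step behavior can be exercised with a dummy model output; the random tensor below stands in for a real UNet prediction, and the shapes are arbitrary:

import torch
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(lower_order_final=True)
scheduler.set_timesteps(10)  # < 15 steps, so the lower-order fallback is active
sample = torch.randn(1, 3, 8, 8)
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)  # stand-in for a UNet call
    sample = scheduler.step(model_output, t, sample).prev_sample
assert not torch.isnan(sample).any(), "Samples have nan numbers"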
