@@ -103,8 +103,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
103
103
solver_type (`str`, default `dpm_solver`):
104
104
the solver type for the second-order solver. Either `dpm_solver` or `taylor`. The solver type slightly
105
105
affects the sample quality, especially for small number of steps.
106
- denoise_final (`bool`, default `True`):
107
- whether to use lower-order solvers in the final steps.
106
+ lower_order_final (`bool`, default `True`):
107
+ whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
108
+ find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
108
109
109
110
"""
110
111
@@ -131,7 +132,7 @@ def __init__(
131
132
dynamic_thresholding_ratio : float = 0.995 ,
132
133
sample_max_value : float = 1.0 ,
133
134
solver_type : str = "dpm_solver" ,
134
- denoise_final : bool = True ,
135
+ lower_order_final : bool = True ,
135
136
):
136
137
if trained_betas is not None :
137
138
self .betas = torch .from_numpy (trained_betas )
@@ -405,17 +406,21 @@ def step(
405
406
else :
406
407
step_index = step_index .item ()
407
408
prev_timestep = 0 if step_index == len (self .timesteps ) - 1 else self .timesteps [step_index + 1 ]
408
- denoise_final = (step_index == len (self .timesteps ) - 1 ) and self .config .denoise_final
409
- denoise_second = (step_index == len (self .timesteps ) - 2 ) and self .config .denoise_final
409
+ lower_order_final = (
410
+ (step_index == len (self .timesteps ) - 1 ) and self .config .lower_order_final and len (self .timesteps ) < 15
411
+ )
412
+ lower_order_second = (
413
+ (step_index == len (self .timesteps ) - 2 ) and self .config .lower_order_final and len (self .timesteps ) < 15
414
+ )
410
415
411
416
model_output = self .convert_model_output (model_output , timestep , sample )
412
417
for i in range (self .config .solver_order - 1 ):
413
418
self .model_outputs [i ] = self .model_outputs [i + 1 ]
414
419
self .model_outputs [- 1 ] = model_output
415
420
416
- if self .config .solver_order == 1 or self .lower_order_nums < 1 or denoise_final :
421
+ if self .config .solver_order == 1 or self .lower_order_nums < 1 or lower_order_final :
417
422
prev_sample = self .dpm_solver_first_order_update (model_output , timestep , prev_timestep , sample )
418
- elif self .config .solver_order == 2 or self .lower_order_nums < 2 or denoise_second :
423
+ elif self .config .solver_order == 2 or self .lower_order_nums < 2 or lower_order_second :
419
424
timestep_list = [self .timesteps [step_index - 1 ], timestep ]
420
425
prev_sample = self .multistep_dpm_solver_second_order_update (
421
426
self .model_outputs , timestep_list , prev_timestep , sample
0 commit comments