diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 9307db89d8d7..8ddd30b0a192 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -21,9 +21,13 @@
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
     """
@@ -251,7 +255,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         self.timesteps = torch.from_numpy(timesteps).to(device)
         self.model_outputs = [None] * self.config.solver_order
         self.sample = None
-        self.orders = self.get_order_list(num_inference_steps)
+
+        if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
+            logger.warning(
+                f"Changing scheduler {self.config} to have `lower_order_final` set to `True` to handle an uneven number of inference steps. Please make sure to use a `num_inference_steps` that is divisible by `solver_order` when setting `lower_order_final=False`."
+            )
+            self.register_to_config(lower_order_final=True)
+
+        self.order_list = self.get_order_list(num_inference_steps)
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
@@ -597,6 +608,12 @@ def step(
         self.model_outputs[-1] = model_output
 
         order = self.order_list[step_index]
+
+        # For img2img, denoising may start at a step where order > 1, which is not possible
+        # here; in that case, lower the order so the first steps fall back to order = 1.
+        while self.model_outputs[-order] is None:
+            order -= 1
+
         # For single-step solvers, we use the initial value at each time with order = 1.
         if order == 1:
             self.sample = sample
diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py
index fd7395e794c7..18a706a1f59b 100644
--- a/tests/schedulers/test_scheduler_dpm_single.py
+++ b/tests/schedulers/test_scheduler_dpm_single.py
@@ -116,6 +116,22 @@ def full_loop(self, scheduler=None, **config):
 
         return sample
 
+    def test_full_uneven_loop(self):
+        scheduler = DPMSolverSinglestepScheduler(**self.get_scheduler_config())
+        num_inference_steps = 50
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+        scheduler.set_timesteps(num_inference_steps)
+
+        # skip the first timesteps so denoising starts at an uneven step index, as in img2img
+        for t in scheduler.timesteps[3:]:
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample).prev_sample
+
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2574) < 1e-3
+
     def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
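A minimal sketch of the patched behavior, assuming the `diffusers` version this diff targets (the step count of 25 is illustrative): with `lower_order_final=False` and a `num_inference_steps` not divisible by `solver_order`, the patched `set_timesteps` logs a warning and flips `lower_order_final` to `True` instead of building an order list that does not cover every step.

```python
from diffusers import DPMSolverSinglestepScheduler

# 25 steps with the default solver_order=2 is the previously failing case:
# 25 % 2 != 0, so the [1, 2] solver cycles could not cover all steps.
scheduler = DPMSolverSinglestepScheduler(lower_order_final=False)
scheduler.set_timesteps(num_inference_steps=25)  # logs a warning and flips the flag

print(scheduler.config.lower_order_final)  # True after the patched set_timesteps
print(len(scheduler.order_list))           # 25 -- one solver order per inference step
```

The second hunk covers the img2img path, where denoising enters `step` partway through the timestep schedule (as the new test simulates by slicing `scheduler.timesteps[3:]`): the `while self.model_outputs[-order] is None` loop lowers the order until enough previous model outputs exist, so the first calls run with order = 1 just as in a fresh sampling loop.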