diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 3399ee2c54cb..337c6603fe75 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -118,6 +118,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
             noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
             of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
+        lambda_min_clipped (`float`, default `-inf`):
+            the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for
+            the cosine (squaredcos_cap_v2) noise schedule.
+        variance_type (`str`, *optional*):
+            Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
+            guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
+            Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
+            diffusion ODEs.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -140,6 +148,8 @@ def __init__(
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
         use_karras_sigmas: Optional[bool] = False,
+        lambda_min_clipped: float = -float("inf"),
+        variance_type: Optional[str] = None,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -187,7 +197,7 @@ def __init__(
         self.lower_order_nums = 0
         self.use_karras_sigmas = use_karras_sigmas

-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -197,8 +207,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             device (`str` or `torch.device`, optional):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
+        # Clipping the minimum of all lambda(t) for numerical stability.
+        # This is critical for the cosine (squaredcos_cap_v2) noise schedule.
+        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
         timesteps = (
-            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
+            np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
@@ -320,9 +333,13 @@ def convert_model_output(
         Returns:
             `torch.FloatTensor`: the converted model output.

         """
+        # DPM-Solver++ needs to solve an integral of the data prediction model.
         if self.config.algorithm_type == "dpmsolver++":
             if self.config.prediction_type == "epsilon":
+                # DPM-Solver and DPM-Solver++ only need the "mean" output.
+                if self.config.variance_type in ["learned_range"]:
+                    model_output = model_output[:, :3]
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                 x0_pred = (sample - sigma_t * model_output) / alpha_t
             elif self.config.prediction_type == "sample":
@@ -343,6 +360,9 @@ def convert_model_output(
         # DPM-Solver needs to solve an integral of the noise prediction model.
         elif self.config.algorithm_type == "dpmsolver":
             if self.config.prediction_type == "epsilon":
+                # DPM-Solver and DPM-Solver++ only need the "mean" output.
+                if self.config.variance_type in ["learned_range"]:
+                    model_output = model_output[:, :3]
                 return model_output
             elif self.config.prediction_type == "sample":
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 049e2b1dbd4d..1d34977d4a57 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -113,6 +113,14 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final (`bool`, default `True`):
             whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable
             this to use up all the function evaluations.
+        lambda_min_clipped (`float`, default `-inf`):
+            the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for
+            the cosine (squaredcos_cap_v2) noise schedule.
+        variance_type (`str`, *optional*):
+            Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's
+            guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
+            Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
+            diffusion ODEs.
     """

@@ -135,6 +143,8 @@ def __init__(
         algorithm_type: str = "dpmsolver++",
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
+        lambda_min_clipped: float = -float("inf"),
+        variance_type: Optional[str] = None,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -226,8 +236,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
         self.num_inference_steps = num_inference_steps
+        # Clipping the minimum of all lambda(t) for numerical stability.
+        # This is critical for the cosine (squaredcos_cap_v2) noise schedule.
+        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
         timesteps = (
-            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
+            np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
@@ -297,6 +310,9 @@ def convert_model_output(
         # DPM-Solver++ needs to solve an integral of the data prediction model.
         if self.config.algorithm_type == "dpmsolver++":
             if self.config.prediction_type == "epsilon":
+                # DPM-Solver and DPM-Solver++ only need the "mean" output.
+                if self.config.variance_type in ["learned_range"]:
+                    model_output = model_output[:, :3]
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                 x0_pred = (sample - sigma_t * model_output) / alpha_t
             elif self.config.prediction_type == "sample":
@@ -317,6 +333,9 @@ def convert_model_output(
         # DPM-Solver needs to solve an integral of the noise prediction model.
         elif self.config.algorithm_type == "dpmsolver":
             if self.config.prediction_type == "epsilon":
+                # DPM-Solver and DPM-Solver++ only need the "mean" output.
+                if self.config.variance_type in ["learned_range"]:
+                    model_output = model_output[:, :3]
                 return model_output
             elif self.config.prediction_type == "sample":
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py
index c1593bae3908..02a2a3882e94 100644
--- a/tests/schedulers/test_scheduler_dpm_multi.py
+++ b/tests/schedulers/test_scheduler_dpm_multi.py
@@ -29,6 +29,8 @@ def get_scheduler_config(self, **kwargs):
             "algorithm_type": "dpmsolver++",
             "solver_type": "midpoint",
             "lower_order_final": False,
+            "lambda_min_clipped": -float("inf"),
+            "variance_type": None,
         }

         config.update(**kwargs)
@@ -187,6 +189,14 @@ def test_lower_order_final(self):
         self.check_over_configs(lower_order_final=True)
         self.check_over_configs(lower_order_final=False)

+    def test_lambda_min_clipped(self):
+        self.check_over_configs(lambda_min_clipped=-float("inf"))
+        self.check_over_configs(lambda_min_clipped=-5.1)
+
+    def test_variance_type(self):
+        self.check_over_configs(variance_type=None)
+        self.check_over_configs(variance_type="learned_range")
+
     def test_inference_steps(self):
         for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
             self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py
index 9dff04e7c998..fd7395e794c7 100644
--- a/tests/schedulers/test_scheduler_dpm_single.py
+++ b/tests/schedulers/test_scheduler_dpm_single.py
@@ -28,6 +28,8 @@ def get_scheduler_config(self, **kwargs):
             "sample_max_value": 1.0,
             "algorithm_type": "dpmsolver++",
             "solver_type": "midpoint",
+            "lambda_min_clipped": -float("inf"),
+            "variance_type": None,
         }

         config.update(**kwargs)
@@ -179,6 +181,14 @@ def test_lower_order_final(self):
         self.check_over_configs(lower_order_final=True)
         self.check_over_configs(lower_order_final=False)

+    def test_lambda_min_clipped(self):
+        self.check_over_configs(lambda_min_clipped=-float("inf"))
+        self.check_over_configs(lambda_min_clipped=-5.1)
+
+    def test_variance_type(self):
+        self.check_over_configs(variance_type=None)
+        self.check_over_configs(variance_type="learned_range")
+
     def test_inference_steps(self):
         for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
             self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
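
Below is a minimal usage sketch (not part of the patch) showing how the two new options are intended to be used together on DPMSolverMultistepScheduler, e.g. for a model trained with a cosine (squaredcos_cap_v2) noise schedule that also predicts variance. The concrete values are illustrative: -5.1 simply mirrors the value exercised by the new tests, and constructing the scheduler directly from keyword arguments (rather than from_pretrained) is assumed for brevity.

from diffusers import DPMSolverMultistepScheduler

# Scheduler for a cosine-schedule, learned-variance model (values are illustrative).
scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,
    beta_schedule="squaredcos_cap_v2",  # cosine schedule, where clipping lambda(t) matters
    lambda_min_clipped=-5.1,            # clip the minimum of lambda(t) for numerical stability
    variance_type="learned_range",      # the model also predicts variance; only the mean is used
)

# Because of the clipping, the largest timestep in the schedule is pulled below
# num_train_timesteps - 1, avoiding the numerically unstable end of the cosine schedule.
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.timesteps[:3])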
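
The variance_type handling added to convert_model_output only strips the predicted variance before the solver update. The short sketch below (also not part of the patch) illustrates the assumption behind model_output[:, :3]: a "learned_range" model concatenates 3 "mean" channels and 3 variance channels along the channel dimension, and DPM-Solver keeps the mean channels only.

import torch

# Fake output of a learned-variance model for a 3-channel sample:
# 3 "mean" channels followed by 3 variance channels along dim 1.
model_output = torch.randn(2, 6, 64, 64)

# What the patched convert_model_output() does when variance_type == "learned_range".
mean_only = model_output[:, :3]
print(mean_only.shape)  # torch.Size([2, 3, 64, 64])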