From 3930d38f2a1004d863e93639459d7b92ab4bf1aa Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Wed, 3 May 2023 01:10:33 +0800 Subject: [PATCH 1/9] fix multistep dpmsolver for cosine schedule (deepfloyd-if) --- .../scheduling_dpmsolver_multistep.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 3399ee2c54cb..8e4f49029a80 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -118,6 +118,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + lambda_min_clipped (`float`, default `-5.1`): + the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for + cosine (squaredcos_cap_v2) noise schedule. + is_predicting_variance (`bool`, default `False`): + whether the model's output contains the predicted Gaussian variance. For example, OpenAI's guided-diffusion + (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution + in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -140,6 +147,8 @@ def __init__( solver_type: str = "midpoint", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -5.1, + is_predicting_variance: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -187,7 +196,7 @@ def __init__( self.lower_order_nums = 0 self.use_karras_sigmas = use_karras_sigmas - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -197,8 +206,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped) timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) @@ -320,6 +332,11 @@ def convert_model_output( Returns: `torch.FloatTensor`: the converted model output. """ + + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.is_predicting_variance: + model_output = model_output[:, :3] + # DPM-Solver++ needs to solve an integral of the data prediction model. 
if self.config.algorithm_type == "dpmsolver++": if self.config.prediction_type == "epsilon": From 4d710370b2ff4c2360a8855f71cb7f23aef37d04 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Wed, 3 May 2023 01:49:00 +0800 Subject: [PATCH 2/9] fix a typo --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 8e4f49029a80..5daeafc846fb 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -333,13 +333,12 @@ def convert_model_output( `torch.FloatTensor`: the converted model output. """ - # DPM-Solver and DPM-Solver++ only need the "mean" output. - if self.is_predicting_variance: - model_output = model_output[:, :3] - # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type == "dpmsolver++": if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. 
+ if self.is_predicting_variance: + model_output = model_output[:, :3] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": From 2ee131cc647d1f19b58a531170245ff12738b74f Mon Sep 17 00:00:00 2001 From: Cheng Lu Date: Wed, 3 May 2023 21:02:16 +0800 Subject: [PATCH 3/9] Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 5daeafc846fb..bb5c73424e63 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -147,7 +147,7 @@ def __init__( solver_type: str = "midpoint", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, - lambda_min_clipped: float = -5.1, + lambda_min_clipped: float = -float("inf"), is_predicting_variance: bool = False, ): if trained_betas is not None: From fe00e172a8c91ba44cf6d30b9d9e874e9bc80d72 Mon Sep 17 00:00:00 2001 From: Cheng Lu Date: Wed, 3 May 2023 21:03:00 +0800 Subject: [PATCH 4/9] Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index bb5c73424e63..e1ee256c70bf 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -121,7 +121,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): lambda_min_clipped (`float`, default `-5.1`): the 
clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for cosine (squaredcos_cap_v2) noise schedule. - is_predicting_variance (`bool`, default `False`): + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's guided-diffusion + (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution + in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. From 5582d64067d1390729aa5c40c8ce0dae1df197ba Mon Sep 17 00:00:00 2001 From: Cheng Lu Date: Wed, 3 May 2023 21:03:10 +0800 Subject: [PATCH 5/9] Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index e1ee256c70bf..74a3dbfee95a 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -151,7 +151,7 @@ def __init__( lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), - is_predicting_variance: bool = False, + variance_type: Optional[str] = None, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) From 928f6a978f2c290427121fee8bdb2103b32e051d Mon Sep 17 00:00:00 2001 From: Cheng Lu Date: Wed, 3 May 2023 21:03:20 +0800 
Subject: [PATCH 6/9] Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 74a3dbfee95a..abd39f567a73 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -118,7 +118,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. - lambda_min_clipped (`float`, default `-5.1`): + lambda_min_clipped (`float`, default `-inf`): the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for cosine (squaredcos_cap_v2) noise schedule. 
variance_type (`str`, *optional*): From 1b9bee8180056adc531eb2869cbb24faf0d960c3 Mon Sep 17 00:00:00 2001 From: Cheng Lu Date: Wed, 3 May 2023 21:03:32 +0800 Subject: [PATCH 7/9] Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index abd39f567a73..0b844814608f 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -340,7 +340,7 @@ def convert_model_output( if self.config.algorithm_type == "dpmsolver++": if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. - if self.is_predicting_variance: + if self.config.variance_type in ["learned_range"]: model_output = model_output[:, :3] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t From 376ca13b9072b16b8d42032c8b49e9a652ca680f Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Wed, 3 May 2023 21:26:12 +0800 Subject: [PATCH 8/9] update all dpmsolver (singlestep, multistep, dpm, dpm++) for cosine noise schedule --- .../scheduling_dpmsolver_multistep.py | 3 +++ .../scheduling_dpmsolver_singlestep.py | 23 ++++++++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 0b844814608f..ec6300c5f2ea 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -362,6 +362,9 @@ def convert_model_output( # DPM-Solver needs to solve an integral of the noise prediction model. 
elif self.config.algorithm_type == "dpmsolver": if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned_range"]: + model_output = model_output[:, :3] return model_output elif self.config.prediction_type == "sample": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 049e2b1dbd4d..c6c855d16957 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -113,6 +113,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable this to use up all the function evaluations. + lambda_min_clipped (`float`, default `-inf`): + the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for + cosine (squaredcos_cap_v2) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's guided-diffusion + (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution + in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. + whether the model's output contains the predicted Gaussian variance. For example, OpenAI's guided-diffusion + (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution + in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. 
""" @@ -135,6 +145,8 @@ def __init__( algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -226,8 +238,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ self.num_inference_steps = num_inference_steps + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped) timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) @@ -297,6 +312,9 @@ def convert_model_output( # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type == "dpmsolver++": if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned_range"]: + model_output = model_output[:, :3] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": @@ -317,6 +335,9 @@ def convert_model_output( # DPM-Solver needs to solve an integral of the noise prediction model. elif self.config.algorithm_type == "dpmsolver": if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. 
+ if self.config.variance_type in ["learned_range"]: + model_output = model_output[:, :3] return model_output elif self.config.prediction_type == "sample": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] From 9ace742d11be45bf7e4bb99da599f480fde4c822 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Wed, 3 May 2023 21:29:04 +0800 Subject: [PATCH 9/9] add test, fix style --- .../schedulers/scheduling_dpmsolver_multistep.py | 13 +++++++------ .../schedulers/scheduling_dpmsolver_singlestep.py | 13 +++++++------ tests/schedulers/test_scheduler_dpm_multi.py | 10 ++++++++++ tests/schedulers/test_scheduler_dpm_single.py | 10 ++++++++++ 4 files changed, 34 insertions(+), 12 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index ec6300c5f2ea..337c6603fe75 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -122,12 +122,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for cosine (squaredcos_cap_v2) noise schedule. variance_type (`str`, *optional*): - Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's guided-diffusion - (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution - in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. - whether the model's output contains the predicted Gaussian variance. For example, OpenAI's guided-diffusion - (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution - in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. 
+ Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index c6c855d16957..1d34977d4a57 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -117,12 +117,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for cosine (squaredcos_cap_v2) noise schedule. variance_type (`str`, *optional*): - Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's guided-diffusion - (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution - in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. - whether the model's output contains the predicted Gaussian variance. For example, OpenAI's guided-diffusion - (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution - in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. 
+ Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. """ diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index c1593bae3908..02a2a3882e94 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -29,6 +29,8 @@ def get_scheduler_config(self, **kwargs): "algorithm_type": "dpmsolver++", "solver_type": "midpoint", "lower_order_final": False, + "lambda_min_clipped": -float("inf"), + "variance_type": None, } config.update(**kwargs) @@ -187,6 +189,14 @@ def test_lower_order_final(self): self.check_over_configs(lower_order_final=True) self.check_over_configs(lower_order_final=False) + def test_lambda_min_clipped(self): + self.check_over_configs(lambda_min_clipped=-float("inf")) + self.check_over_configs(lambda_min_clipped=-5.1) + + def test_variance_type(self): + self.check_over_configs(variance_type=None) + self.check_over_configs(variance_type="learned_range") + def test_inference_steps(self): for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0) diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index 9dff04e7c998..fd7395e794c7 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ 
b/tests/schedulers/test_scheduler_dpm_single.py @@ -28,6 +28,8 @@ def get_scheduler_config(self, **kwargs): "sample_max_value": 1.0, "algorithm_type": "dpmsolver++", "solver_type": "midpoint", + "lambda_min_clipped": -float("inf"), + "variance_type": None, } config.update(**kwargs) @@ -179,6 +181,14 @@ def test_lower_order_final(self): self.check_over_configs(lower_order_final=True) self.check_over_configs(lower_order_final=False) + def test_lambda_min_clipped(self): + self.check_over_configs(lambda_min_clipped=-float("inf")) + self.check_over_configs(lambda_min_clipped=-5.1) + + def test_variance_type(self): + self.check_over_configs(variance_type=None) + self.check_over_configs(variance_type="learned_range") + def test_inference_steps(self): for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)