From 326172d1b28500baacc408d4cd2c7628022d394d Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 3 Jul 2023 15:55:16 +0200 Subject: [PATCH 01/30] Add timestep_spacing to DDPM, LMSDiscrete, PNDM. --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- src/diffusers/schedulers/scheduling_ddpm.py | 35 ++++++++++++++-- .../schedulers/scheduling_lms_discrete.py | 40 +++++++++++++++++-- src/diffusers/schedulers/scheduling_pndm.py | 29 +++++++++++--- 4 files changed, 93 insertions(+), 13 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index bab6f8acea03..3ffa58c55f2d 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -302,7 +302,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.num_inference_steps = num_inference_steps - # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 + # "leading" and "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 5d24766d68c7..43d252f1c3e8 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -114,6 +114,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -134,6 +141,8 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + steps_offset: int = 1, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -228,11 +237,31 @@ def set_timesteps( ) self.num_inference_steps = num_inference_steps - - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) self.custom_timesteps = False + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
+        # of https://arxiv.org/abs/2305.08891
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -117,6 +124,8 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, use_karras_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -140,9 +149,6 @@ def __init__( sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - # setable values self.num_inference_steps = None self.use_karras_sigmas = use_karras_sigmas @@ -150,6 +156,14 @@ def __init__( self.derivatives = [] self.is_scale_input_called = False + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing == "linspace": + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -207,6 +221,26 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 01c02a21bbfc..7ce348c7e2d0 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -85,11 +85,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. 
steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. - """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -106,6 +108,7 @@ def __init__( skip_prk_steps: bool = False, set_alpha_to_one: bool = False, prediction_type: str = "epsilon", + timestep_spacing: str = "leading", steps_offset: int = 0, ): if trained_betas is not None: @@ -159,11 +162,25 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() - self._timesteps += self.config.steps_offset + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + self._timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() + self._timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype(np.int64) + self._timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to From 9ac281b0ee8eef696860fc63d2e1c32db6240050 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 3 Jul 2023 16:02:23 +0200 Subject: [PATCH 02/30] Remove spurious line. --- src/diffusers/schedulers/scheduling_lms_discrete.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index cd323f0de140..c8fda1df8021 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -219,8 +219,6 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
         # of https://arxiv.org/abs/2305.08891
+        # of https://arxiv.org/abs/2305.08891
+ ) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 100e2012ea20..e734ef4b91e8 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -78,6 +78,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -93,6 +100,8 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -128,6 +137,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing == "linspace": + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -166,7 +183,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) @@ -180,9 +215,6 @@ def set_timesteps( sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - timesteps = torch.from_numpy(timesteps) timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 2fa0431e1292..f7591352d193 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -78,6 +78,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -92,6 +99,8 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -127,6 +136,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing == "linspace": + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -169,7 +186,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
+        # of https://arxiv.org/abs/2305.08891
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -91,6 +98,8 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -126,6 +135,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing == "linspace": + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -168,7 +185,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) @@ -185,9 +220,6 @@ def set_timesteps( [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]] ) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - if str(device).startswith("mps"): # mps does not support float64 timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) From 5e3c4f7fa2374f19841a46d0387b19e67233836c Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 3 Jul 2023 16:30:13 +0200 Subject: [PATCH 04/30] Add `linspace` to DDIM --- src/diffusers/schedulers/scheduling_ddim.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 3ffa58c55f2d..df5b72bde3c9 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -302,8 +302,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.num_inference_steps = num_inference_steps - # "leading" and "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "leading": + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
+        # of https://arxiv.org/abs/2305.08891
         # deviation of the initial noise distribution
                 If `None`, the timesteps are not moved.
         # of https://arxiv.org/abs/2305.08891
+ timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -149,6 +156,8 @@ def __init__( prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, noise_sampler_seed: Optional[int] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -187,6 +196,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -226,7 +243,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) @@ -242,9 +277,6 @@ def set_timesteps( sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - timesteps = torch.from_numpy(timesteps) second_order_timesteps = torch.from_numpy(second_order_timesteps) timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 7233258a4766..d10488b4184b 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -121,6 +121,20 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -145,6 +159,8 @@ def __init__( disable_corrector: List[int] = [], solver_p: SchedulerMixin = None, use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -199,12 +215,40 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
+        # of https://arxiv.org/abs/2305.08891
--- src/diffusers/configuration_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 1a030e467134..59ade8da4404 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -423,6 +423,10 @@ def _get_init_keys(cls): @classmethod def extract_init_dict(cls, config_dict, **kwargs): + # Skip keys that were not present in the original config, so default __init__ values were used + used_defaults = config_dict.get("_use_default_values", []) + config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"} + # 0. Copy origin config dict original_dict = dict(config_dict.items()) @@ -450,6 +454,10 @@ def extract_init_dict(cls, config_dict, **kwargs): else: compatible_classes = [] + # Keys not present in the config - default values were used + # TODO: remove the ones passed in kwargs? + used_default_keys = set(expected_keys) - set(original_dict.keys()) + expected_keys_comp_cls = set() for c in compatible_classes: expected_keys_c = cls._get_init_keys(c) @@ -503,6 +511,9 @@ def extract_init_dict(cls, config_dict, **kwargs): # 7. Define "hidden" config parameters that were saved for compatible classes hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} + # 8. Take note of the parameters that were not present in the loaded config + hidden_config_dict["_use_default_values"] = used_default_keys + return init_dict, unused_kwargs, hidden_config_dict @classmethod From 81c5a1e9a89c28c54913ad14e6d73100bd27585c Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 4 Jul 2023 12:02:03 +0200 Subject: [PATCH 11/30] Apply suggestions from code review Co-authored-by: Patrick von Platen --- src/diffusers/configuration_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 59ade8da4404..81509384e79a 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -454,10 +454,6 @@ def extract_init_dict(cls, config_dict, **kwargs): else: compatible_classes = [] - # Keys not present in the config - default values were used - # TODO: remove the ones passed in kwargs? - used_default_keys = set(expected_keys) - set(original_dict.keys()) - expected_keys_comp_cls = set() for c in compatible_classes: expected_keys_c = cls._get_init_keys(c) @@ -512,7 +508,7 @@ def extract_init_dict(cls, config_dict, **kwargs): hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} # 8. 
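For illustration, a minimal sketch of the behavior this change enables. It
mirrors the regression test added later in this series; the pipeline id and
the PNDM -> Euler pairing are assumptions for the example, not part of the
diff:

    from diffusers import DiffusionPipeline, EulerDiscreteScheduler

    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # The saved scheduler config predates `timestep_spacing`, so PNDM fills it
    # from its own class default ("leading") and the key is recorded under
    # `_use_default_values`.
    assert pipe.scheduler.config.timestep_spacing == "leading"

    # Because the value was never explicitly configured, switching schedulers
    # re-defaults it for the new class instead of inheriting "leading".
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    assert pipe.scheduler.config.timestep_spacing == "linspace"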
         # Take note of the parameters that were not present in the loaded config
- """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] From 8f9c73c1d4aab105f82501516cdfb4584878e13d Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 4 Jul 2023 12:28:19 +0200 Subject: [PATCH 13/30] Test: default args not in config --- tests/schedulers/test_schedulers.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index a2d065f388bd..b2df100b5d3e 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -24,6 +24,8 @@ import diffusers from diffusers import ( + DDIMScheduler, + DiffusionPipeline, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, IPNDMScheduler, @@ -202,6 +204,31 @@ def test_save_load_from_different_config_comp_schedulers(self): assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n" assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n" + def test_default_arguments_not_in_config(self): + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype=torch.float16 + ) + assert pipe.scheduler.__class__ == DDIMScheduler + + # Default for PNDMScheduler + assert pipe.scheduler.config.timestep_spacing == "leading" + + # Switch to a different one, verify we use the default for that class + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.timestep_spacing == "linspace" + + # Override with kwargs + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") + assert pipe.scheduler.config.timestep_spacing == "trailing" + + # Verify overridden kwargs stick + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.timestep_spacing == "trailing" + + # And stick + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.timestep_spacing == "trailing" + class SchedulerCommonTest(unittest.TestCase): scheduler_classes = () From d833276caf7fec3161f7b9d5dbd1bd5ca0f3652a Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 4 Jul 2023 12:29:26 +0200 Subject: [PATCH 14/30] Style --- src/diffusers/configuration_utils.py | 2 +- src/diffusers/schedulers/scheduling_ddim.py | 9 ++++++--- src/diffusers/schedulers/scheduling_ddpm.py | 10 ++++++---- .../schedulers/scheduling_deis_multistep.py | 14 ++------------ .../scheduling_dpmsolver_multistep.py | 19 +++---------------- .../scheduling_euler_ancestral_discrete.py | 4 +++- .../schedulers/scheduling_euler_discrete.py | 4 +++- .../schedulers/scheduling_lms_discrete.py | 4 +++- src/diffusers/schedulers/scheduling_pndm.py | 8 ++++++-- .../schedulers/scheduling_unipc_multistep.py | 14 ++------------ 10 files changed, 35 insertions(+), 53 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 81509384e79a..33bc32f6b743 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -508,7 +508,7 @@ def extract_init_dict(cls, config_dict, **kwargs): hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} # 8. 
Take note of the parameters that were not present in the loaded config - hidden_config_dict["_use_default_values"] = expected_keys - set(init_dict) + hidden_config_dict["_use_default_values"] = expected_keys - set(init_dict) return init_dict, unused_kwargs, hidden_config_dict diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index df5b72bde3c9..99602d14038b 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -304,9 +304,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round()[ - ::-1 - ].copy().astype(np.int64) + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 43d252f1c3e8..930e14108a67 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -241,9 +241,12 @@ def set_timesteps( # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round()[ - ::-1 - ].copy().astype(np.int64) + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio @@ -261,7 +264,6 @@ def set_timesteps( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
) - self.timesteps = torch.from_numpy(timesteps).to(device) def _get_variance(self, t, predicted_variance=None, variance_type=None): diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 8cc6a0240a28..c504fb19231a 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -206,23 +206,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = ( - (np.arange(0, num_inference_steps + 1) * step_ratio) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = ( - np.arange(self.config.num_train_timesteps, 0, -step_ratio) - .round() - .copy() - .astype(np.int64) - ) + timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) timesteps -= 1 else: raise ValueError( diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index a323d69a1e01..528b7b838b1c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -231,32 +231,19 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( - np.linspace(0, last_timestep - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) + np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = last_timestep // (num_inference_steps + 1) # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = ( - (np.arange(0, num_inference_steps + 1) * step_ratio) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = ( - np.arange(last_timestep, 0, -step_ratio) - .round() - .copy() - .astype(np.int64) - ) + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) timesteps -= 1 else: raise ValueError( diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 132729b7eb88..6b8c2f1a8a28 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -194,7 +194,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ + ::-1 + ].copy() elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index ca1e727fc95c..d70869d3a985 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -201,7 +201,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ + ::-1 + ].copy() elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index cfa3fe09d273..1256660b843c 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -221,7 +221,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ + ::-1 + ].copy() elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 7ce348c7e2d0..70ee1301129c 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -164,7 +164,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.num_inference_steps = num_inference_steps # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": - self._timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64) + self._timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64) + ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio @@ -175,7 +177,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype(np.int64) + self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype( + np.int64 + ) self._timesteps -= 1 else: raise ValueError( diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index d10488b4184b..48b5af7aa2cc 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -227,23 +227,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = ( - (np.arange(0, num_inference_steps + 1) * step_ratio) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = ( - np.arange(self.config.num_train_timesteps, 0, -step_ratio) - .round() - .copy() - .astype(np.int64) - ) + timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) timesteps -= 1 else: raise ValueError( From 3de3909375cebdfb77add20b4296cb318341bea5 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 4 Jul 2023 12:35:04 +0200 Subject: [PATCH 15/30] Fix scheduler name in test --- tests/schedulers/test_schedulers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index b2df100b5d3e..cecc5cd00499 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -210,7 +210,7 @@ def test_default_arguments_not_in_config(self): ) assert pipe.scheduler.__class__ == DDIMScheduler - # Default for PNDMScheduler + # Default for DDIMScheduler assert pipe.scheduler.config.timestep_spacing == "leading" # Switch to a different one, verify we use the default for that class From 5080d7faa1c724b12d91da442bd30aedb6d4a16c Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 4 Jul 2023 13:25:00 +0200 Subject: [PATCH 16/30] Remove duplicated entries --- src/diffusers/schedulers/scheduling_unipc_multistep.py | 7 ------- 1 file changed, 7 deletions(-) diff --git 
a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 48b5af7aa2cc..a1b072632fd0 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -128,13 +128,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. - timestep_spacing (`str`, default `"linspace"`): - The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample - Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. - steps_offset (`int`, default `0`): - an offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in - stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] From 083c8aab9695ab076a556bb69af3351c5a78ebca Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 4 Jul 2023 13:30:16 +0200 Subject: [PATCH 17/30] Add test for solver_type This test currently fails in main. When switching from DEIS to UniPC, solver_type is "logrho" (the default value from DEIS), which gets translated to "bh1" by UniPC. This is different from the default value for UniPC: "bh2". This is where the translation happens: https://github.com/huggingface/diffusers/blob/36d22d0709dc19776e3016fb3392d0f5578b0ab2/src/diffusers/schedulers/scheduling_unipc_multistep.py#L171 --- tests/schedulers/test_schedulers.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index cecc5cd00499..2554c0d1ec15 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -25,11 +25,13 @@ import diffusers from diffusers import ( DDIMScheduler, + DEISMultistepScheduler, DiffusionPipeline, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, IPNDMScheduler, LMSDiscreteScheduler, + UniPCMultistepScheduler, VQDiffusionScheduler, logging, ) @@ -229,6 +231,19 @@ def test_default_arguments_not_in_config(self): pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) assert pipe.scheduler.config.timestep_spacing == "trailing" + def test_default_solver_type_after_switch(self): + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype=torch.float16 + ) + assert pipe.scheduler.__class__ == DDIMScheduler + + pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.solver_type == "logrho" + + # Switch to UniPC, verify the solver is the default + pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.solver_type == "bh2" + class SchedulerCommonTest(unittest.TestCase): scheduler_classes = () From bef7669b1b84c8388d511f0e0941e64f82087b7e Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 4 Jul 2023 14:46:55 +0200 Subject: [PATCH 18/30] UniPC: use same default for solver_type Fixes a bug when switching to UniPC from another scheduler (e.g., DEIS) that uses a different solver type. The solver is now the same as if we had instantiated the scheduler directly.
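For reference, the old and new behavior can be checked with a couple of config round-trips. A minimal sketch (it assumes only the two scheduler classes already imported by the test in the previous commit, and the solver defaults quoted above):

    from diffusers import DEISMultistepScheduler, UniPCMultistepScheduler

    deis = DEISMultistepScheduler()
    assert deis.config.solver_type == "logrho"  # the DEIS default

    # UniPC receives solver_type="logrho" from the DEIS config. "logrho" is not a
    # UniPC solver, so UniPC substitutes a supported one: "bh1" before this patch,
    # its own default "bh2" with this patch applied.
    unipc = UniPCMultistepScheduler.from_config(deis.config)
    assert unipc.config.solver_type == "bh2"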
--- src/diffusers/schedulers/scheduling_unipc_multistep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index a1b072632fd0..3caa01a58562 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -182,7 +182,7 @@ def __init__( if solver_type not in ["bh1", "bh2"]: if solver_type in ["midpoint", "heun", "logrho"]: - self.register_to_config(solver_type="bh1") + self.register_to_config(solver_type="bh2") else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") From 21d7eea556f7e4c17146c1e14bef46e98155da4e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 5 Jul 2023 11:45:57 +0200 Subject: [PATCH 19/30] do not save _use_default_values --- src/diffusers/configuration_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 33bc32f6b743..9a38da711426 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -551,8 +551,9 @@ def to_json_saveable(value): return value config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} - # Don't save "_ignore_files" + # Don't save "_ignore_files" or "_use_default_values" config_dict.pop("_ignore_files", None) config_dict.pop("_use_default_values", None) return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" From cb1109f06fd56c86c47ac0bb98f27b213cf4f5e7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 5 Jul 2023 12:15:52 +0200 Subject: [PATCH 20/30] fix more --- src/diffusers/configuration_utils.py | 3 +- .../schedulers/scheduling_euler_discrete.py | 3 -- tests/others/test_config.py | 34 +++++++++++++++++++ 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 9a38da711426..aaffd9fd1569 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -508,7 +508,8 @@ def extract_init_dict(cls, config_dict, **kwargs): hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} # 8.
Take note of the parameters that were not present in the loaded config - hidden_config_dict["_use_default_values"] = expected_keys - set(init_dict) + if len(expected_keys - set(init_dict)) > 0: + hidden_config_dict["_use_default_values"] = expected_keys - set(init_dict) return init_dict, unused_kwargs, hidden_config_dict diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index d70869d3a985..c780d1aceb07 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -146,9 +146,6 @@ def __init__( sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() diff --git a/tests/others/test_config.py b/tests/others/test_config.py index a29190c199ca..eb64aca2ba4f 100644 --- a/tests/others/test_config.py +++ b/tests/others/test_config.py @@ -75,6 +75,22 @@ def __init__( pass +class SampleObject4(ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + e=[1, 5], + f=[5, 4], + ): + pass + + class ConfigTester(unittest.TestCase): def test_load_not_from_mixin(self): with self.assertRaises(ValueError): @@ -233,3 +249,21 @@ def test_load_dpmsolver(self): assert dpm.__class__ == DPMSolverMultistepScheduler # no warning should be thrown assert cap_logger.out == "" + + def test_use_default_values(self): + # let's first save a config that should be in the form + # a=2, + # b=5, + # c=(2, 5), + # d="for diffusion", + # e=[1, 3], + + config = SampleObject() + with tempfile.TemporaryDirectory() as tmpdirname: + config.save(tmpdirname) + + # now loading it with SampleObject2 should put f into `_use_default_values` + config = SampleObject2.from_config(tmpdirname) + + assert "f" in config._use_default_values + assert config.f == [1, 4] From 46c5c8f5c5e3c84a264091dddb823da76c6998ec Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 5 Jul 2023 12:36:11 +0200 Subject: [PATCH 21/30] fix all --- .../schedulers/scheduling_euler_discrete.py | 9 +++++++++ tests/others/test_config.py | 16 ++++++++++++++-- tests/schedulers/test_scheduler_euler.py | 4 ++-- .../schedulers/test_scheduler_euler_ancestral.py | 4 ++-- tests/schedulers/test_scheduler_lms.py | 2 +- 5 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index c780d1aceb07..fc52c50ebc7f 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -107,6 +107,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. 
+ steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -123,6 +130,8 @@ def __init__( prediction_type: str = "epsilon", interpolation_type: str = "linear", use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) diff --git a/tests/others/test_config.py b/tests/others/test_config.py index eb64aca2ba4f..ba2ae439080f 100644 --- a/tests/others/test_config.py +++ b/tests/others/test_config.py @@ -260,10 +260,22 @@ def test_use_default_values(self): config = SampleObject() with tempfile.TemporaryDirectory() as tmpdirname: - config.save(tmpdirname) + config.save_config(tmpdirname) # now loading it with SampleObject2 should put f into `_use_default_values` config = SampleObject2.from_config(tmpdirname) assert "f" in config._use_default_values - assert config.f == [1, 4] + assert config.f == [1, 3] + + # now loading the config, should **NOT** use [1, 3] for `f`, but the default [1, 4] value + # **BECAUSE** it is part of `config._use_default_values` + new_config = SampleObject4.from_config(config.config) + assert new_config.f == [5, 4] + + config.config._use_default_values.pop() + new_config_2 = SampleObject4.from_config(config.config) + assert new_config_2.f == [1, 3] + + # Nevertheless "e" should still be correctly loaded to [1, 3] from SampleObject2 instead of defaulting to [1, 5] + assert new_config_2.e == [1, 3] diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py index aa46ef31885a..0c3b065161db 100644 --- a/tests/schedulers/test_scheduler_euler.py +++ b/tests/schedulers/test_scheduler_euler.py @@ -101,7 +101,7 @@ def test_full_loop_device(self): generator = torch.manual_seed(0) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for t in scheduler.timesteps: @@ -128,7 +128,7 @@ def test_full_loop_device_karras_sigmas(self): generator = torch.manual_seed(0) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for t in scheduler.timesteps: diff --git a/tests/schedulers/test_scheduler_euler_ancestral.py b/tests/schedulers/test_scheduler_euler_ancestral.py index 5fa36be6bc64..9866bd12d6af 100644 --- a/tests/schedulers/test_scheduler_euler_ancestral.py +++ b/tests/schedulers/test_scheduler_euler_ancestral.py @@ -47,7 +47,7 @@ def test_full_loop_no_noise(self): generator = torch.manual_seed(0) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for i, t in enumerate(scheduler.timesteps): @@ -100,7 +100,7 @@ def test_full_loop_device(self): generator = torch.manual_seed(0) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for t in scheduler.timesteps: diff --git a/tests/schedulers/test_scheduler_lms.py 
b/tests/schedulers/test_scheduler_lms.py index 2682886a788d..1e0a8212354d 100644 --- a/tests/schedulers/test_scheduler_lms.py +++ b/tests/schedulers/test_scheduler_lms.py @@ -97,7 +97,7 @@ def test_full_loop_device(self): scheduler.set_timesteps(self.num_inference_steps, device=torch_device) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for i, t in enumerate(scheduler.timesteps): From dc8c70b009993f163a633b446355811929aa020a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 5 Jul 2023 14:22:25 +0200 Subject: [PATCH 22/30] fix schedulers --- tests/schedulers/test_scheduler_unipc.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index 62cffc67388c..171ee85be1d3 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -23,7 +23,7 @@ def get_scheduler_config(self, **kwargs): "beta_end": 0.02, "beta_schedule": "linear", "solver_order": 2, - "solver_type": "bh1", + "solver_type": "bh2", } config.update(**kwargs) @@ -144,7 +144,7 @@ def test_switch(self): sample = self.full_loop(scheduler=scheduler) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.2521) < 1e-3 + assert abs(result_mean.item() - 0.2464) < 1e-3 scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config) scheduler = DEISMultistepScheduler.from_config(scheduler.config) @@ -154,7 +154,7 @@ def test_switch(self): sample = self.full_loop(scheduler=scheduler) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.2521) < 1e-3 + assert abs(result_mean.item() - 0.2464) < 1e-3 def test_timesteps(self): for timesteps in [25, 50, 100, 999, 1000]: @@ -206,13 +206,13 @@ def test_full_loop_no_noise(self): sample = self.full_loop() result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.2521) < 1e-3 + assert abs(result_mean.item() - 0.2464) < 1e-3 def test_full_loop_with_v_prediction(self): sample = self.full_loop(prediction_type="v_prediction") result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.1096) < 1e-3 + assert abs(result_mean.item() - 0.1014) < 1e-3 def test_fp16_support(self): scheduler_class = self.scheduler_classes[0] From 64703530eb7d21eafbd77de0ae5d7b5465b6a62f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 5 Jul 2023 14:40:36 +0200 Subject: [PATCH 23/30] fix more --- src/diffusers/configuration_utils.py | 6 ++++++ tests/schedulers/test_schedulers.py | 6 +++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index aaffd9fd1569..fc793340e7d2 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -608,6 +608,9 @@ def inner_init(self, *args, **kwargs): if k not in ignore and k not in new_kwargs } ) + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs) + new_kwargs = {**config_init_kwargs, **new_kwargs} getattr(self, "register_to_config")(**new_kwargs) init(self, *args, **init_kwargs) @@ -652,6 +655,9 @@ def init(self, *args, **kwargs): name = fields[i].name new_kwargs[name] = arg + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = 
set(new_kwargs.keys()) - set(init_kwargs) + getattr(self, "register_to_config")(**new_kwargs) original_init(self, *args, **kwargs) diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index 2554c0d1ec15..d1ae333c0cd2 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -456,7 +456,11 @@ def test_from_pretrained(self): scheduler.save_pretrained(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname) - assert scheduler.config == new_scheduler.config + # `_use_default_values` should not exist for just saved & loaded scheduler + scheduler_config = dict(scheduler.config) + del scheduler_config["_use_default_values"] + + assert scheduler_config == new_scheduler.config def test_step_shape(self): kwargs = dict(self.forward_default_kwargs) From 72949cb65f07e022b157ebc49b1c7c80b4b247be Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 5 Jul 2023 14:46:44 +0200 Subject: [PATCH 24/30] finish for real --- src/diffusers/configuration_utils.py | 7 +++---- tests/others/test_config.py | 7 +++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index fc793340e7d2..202905db52c6 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -507,10 +507,6 @@ def extract_init_dict(cls, config_dict, **kwargs): # 7. Define "hidden" config parameters that were saved for compatible classes hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} - # 8. Take note of the parameters that were not present in the loaded config - if len(expected_keys - set(init_dict)) > 0: - hidden_config_dict["_use_default_values"] = expected_keys - set(init_dict) - return init_dict, unused_kwargs, hidden_config_dict @classmethod @@ -608,6 +604,8 @@ def inner_init(self, *args, **kwargs): if k not in ignore and k not in new_kwargs } ) + + # Take note of the parameters that were not present in the loaded config if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs) @@ -655,6 +653,7 @@ def init(self, *args, **kwargs): name = fields[i].name new_kwargs[name] = arg + # Take note of the parameters that were not present in the loaded config if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs) diff --git a/tests/others/test_config.py b/tests/others/test_config.py index ba2ae439080f..d1f8a6e054d4 100644 --- a/tests/others/test_config.py +++ b/tests/others/test_config.py @@ -153,6 +153,7 @@ def test_save_load(self): assert config.pop("c") == (2, 5) # instantiated as tuple assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json + config.pop("_use_default_values") assert config == new_config def test_load_ddim_from_pndm(self): @@ -259,6 +260,12 @@ def test_use_default_values(self): # e=[1, 3], config = SampleObject() + + config_dict = {k: v for k, v in config.config.items() if not k.startswith("_")} + + # make sure that default config has all keys in `_use_default_values` + assert set(config_dict.keys()) == config.config._use_default_values + with tempfile.TemporaryDirectory() as tmpdirname: config.save_config(tmpdirname) From 45328ef6447c67c824b09e5331aa9ade57a9bbda Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 5 Jul 2023 14:52:38 +0200 Subject: [PATCH 25/30] finish for real --- 
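The hunks below port the same three-way `timestep_spacing` logic to the parallel DDIM/DDPM schedulers. For a worked reference, this standalone numpy sketch condenses the expressions added below and shows the schedules they produce (the concrete values assume `num_train_timesteps=1000`, `num_inference_steps=10`, and `steps_offset=0`):

    import numpy as np

    num_train_timesteps, num_inference_steps, steps_offset = 1000, 10, 0

    # "linspace": evenly spaced across the whole range, so both endpoints are used
    linspace = np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1].astype(np.int64)
    # -> [999, 888, 777, 666, 555, 444, 333, 222, 111, 0]

    # "leading": integer stride counting up from 0; timestep 999 is never reached
    step_ratio = num_train_timesteps // num_inference_steps
    leading = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) + steps_offset
    # -> [900, 800, 700, 600, 500, 400, 300, 200, 100, 0]

    # "trailing": stride counting down from the end, so sampling starts at 999
    trailing = np.round(np.arange(num_train_timesteps, 0, -num_train_timesteps / num_inference_steps)).astype(np.int64) - 1
    # -> [999, 899, 799, 699, 599, 499, 399, 299, 199, 99]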
.../schedulers/scheduling_ddim_parallel.py | 11 +++++-- .../schedulers/scheduling_ddpm_parallel.py | 30 +++++++++++++++++-- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py index 22b7d8ec97dc..8875aa73208b 100644 --- a/src/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -321,8 +321,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.num_inference_steps = num_inference_steps - # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "leading": + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index 2719d90b9314..cbc5364bcbb7 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -138,6 +138,8 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + steps_offset: int = 1, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -234,11 +236,33 @@ def set_timesteps( ) self.num_inference_steps = num_inference_steps - - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) self.custom_timesteps = False + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
+ ) + self.timesteps = torch.from_numpy(timesteps).to(device) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._get_variance From 8f4c443bcfb54d576f74f4b51239aa6477d391af Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 5 Jul 2023 15:10:05 +0200 Subject: [PATCH 26/30] flaky tests --- .../stable_diffusion/test_stable_diffusion_panorama.py | 2 +- .../stable_diffusion/test_stable_diffusion_pix2pix_zero.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py index 32541c980a15..080bd0091f4f 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py @@ -186,7 +186,7 @@ def test_stable_diffusion_panorama_euler(self): assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4886, 0.5586, 0.4476, 0.5053, 0.6013, 0.4737, 0.5538, 0.5100, 0.4927]) + expected_slice = np.array([0.4024, 0.6510, 0.4901, 0.5378, 0.5813, 0.5622, 0.4795, 0.4467, 0.4952]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 1b17f8b31be9..eec8343cb15b 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -304,7 +304,8 @@ def test_stable_diffusion_pix2pix_zero_ddpm(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4861, 0.5053, 0.5038, 0.3994, 0.3562, 0.4768, 0.5172, 0.5280, 0.4938]) + print(torch.from_numpy(image_slice.flatten())) + expected_slice = np.array([0.4833, 0.5052, 0.5034, 0.4022, 0.3577, 0.4766, 0.5176, 0.5288, 0.4942]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 From cedcb1b449f70f551327de99021e5b3bba34260d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 5 Jul 2023 15:22:45 +0200 Subject: [PATCH 27/30] Update tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py --- .../stable_diffusion/test_stable_diffusion_pix2pix_zero.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index eec8343cb15b..7d8613777554 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -304,7 +304,6 @@ def test_stable_diffusion_pix2pix_zero_ddpm(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - print(torch.from_numpy(image_slice.flatten())) expected_slice = np.array([0.4833, 0.5052, 0.5034, 0.4022, 0.3577, 0.4766, 0.5176, 0.5288, 0.4942]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 From a8bdf0658557aee5dd7b43e263a0f8bb5de6a43f Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 5 Jul 2023 15:32:21 +0200 Subject: [PATCH 28/30] Default steps_offset to 0. 
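With `timestep_spacing="leading"` the offset is added to every timestep, so a default of 1 would shift the whole DDPM schedule relative to the behavior before this series. A quick sketch of the difference (assuming 1000 training timesteps and 10 inference steps):

    import numpy as np

    base = (np.arange(0, 10) * (1000 // 10)).round()[::-1].astype(np.int64)
    print(base + 1)  # steps_offset=1 (previous default): [901 801 ... 101   1]
    print(base + 0)  # steps_offset=0 (this patch):       [900 800 ... 100   0]

Defaulting to 0 keeps the leading schedule identical to the one these schedulers produced before `timestep_spacing` was introduced.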
--- src/diffusers/schedulers/scheduling_ddpm.py | 2 +- src/diffusers/schedulers/scheduling_ddpm_parallel.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 930e14108a67..ddf27d409d88 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -142,7 +142,7 @@ def __init__( clip_sample_range: float = 1.0, sample_max_value: float = 1.0, timestep_spacing: str = "leading", - steps_offset: int = 1, + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index cbc5364bcbb7..a169a16fd6e9 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -139,7 +139,7 @@ def __init__( clip_sample_range: float = 1.0, sample_max_value: float = 1.0, timestep_spacing: str = "leading", - steps_offset: int = 1, + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) From f5a1d2c479ebf8f1d774188efe8b22dda8c9ed29 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 5 Jul 2023 15:32:59 +0200 Subject: [PATCH 29/30] Add missing docstrings --- src/diffusers/schedulers/scheduling_ddpm_parallel.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index a169a16fd6e9..e4d858efde8f 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -116,6 +116,13 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] From 910d609bffd6b91884e91732d569f8f70bfb7629 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 5 Jul 2023 15:37:35 +0200 Subject: [PATCH 30/30] Apply suggestions from code review --- .../stable_diffusion/test_stable_diffusion_pix2pix_zero.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 7d8613777554..1b17f8b31be9 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -304,7 +304,7 @@ def test_stable_diffusion_pix2pix_zero_ddpm(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4833, 0.5052, 0.5034, 0.4022, 0.3577, 0.4766, 0.5176, 0.5288, 0.4942]) + expected_slice = np.array([0.4861, 0.5053, 0.5038, 0.3994, 0.3562, 0.4768, 0.5172, 0.5280, 0.4938]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3