diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 1a030e467134..202905db52c6 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -423,6 +423,10 @@ def _get_init_keys(cls): @classmethod def extract_init_dict(cls, config_dict, **kwargs): + # Skip keys that were not present in the original config, so default __init__ values were used + used_defaults = config_dict.get("_use_default_values", []) + config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"} + # 0. Copy origin config dict original_dict = dict(config_dict.items()) @@ -544,8 +548,9 @@ def to_json_saveable(value): return value config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} - # Don't save "_ignore_files" + # Don't save "_ignore_files" or "_use_default_values" config_dict.pop("_ignore_files", None) + config_dict.pop("_use_default_values", None) return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" @@ -599,6 +604,11 @@ def inner_init(self, *args, **kwargs): if k not in ignore and k not in new_kwargs } ) + + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs) + new_kwargs = {**config_init_kwargs, **new_kwargs} getattr(self, "register_to_config")(**new_kwargs) init(self, *args, **init_kwargs) @@ -643,6 +653,10 @@ def init(self, *args, **kwargs): name = fields[i].name new_kwargs[name] = arg + # Take note of the parameters that were not present in the loaded config + if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: + new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs) + getattr(self, "register_to_config")(**new_kwargs) original_init(self, *args, **kwargs) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index bab6f8acea03..99602d14038b 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -302,8 +302,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.num_inference_steps = num_inference_steps - # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "leading": + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py index 22b7d8ec97dc..8875aa73208b 100644 --- a/src/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -321,8 +321,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.num_inference_steps = num_inference_steps - # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "leading": + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 5d24766d68c7..ddf27d409d88 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -114,6 +114,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -134,6 +141,8 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -228,11 +237,33 @@ def set_timesteps( ) self.num_inference_steps = num_inference_steps - - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) self.custom_timesteps = False + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + self.timesteps = torch.from_numpy(timesteps).to(device) def _get_variance(self, t, predicted_variance=None, variance_type=None): diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index 2719d90b9314..e4d858efde8f 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -116,6 +116,13 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -138,6 +145,8 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -234,11 +243,33 @@ def set_timesteps( ) self.num_inference_steps = num_inference_steps - - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) self.custom_timesteps = False + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + self.timesteps = torch.from_numpy(timesteps).to(device) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._get_variance diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 56c362018c18..c504fb19231a 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -107,6 +107,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -129,6 +136,8 @@ def __init__( solver_type: str = "logrho", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -185,12 +194,30 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index d7c29d5488a5..528b7b838b1c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -134,6 +134,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -158,6 +165,8 @@ def __init__( use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -217,12 +226,29 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # Clipping the minimum of all lambda(t) for numerical stability. # This is critical for cosine (squaredcos_cap_v2) noise schedule. clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index ae9229981152..da8b71788b75 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -133,6 +133,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. noise_sampler_seed (`int`, *optional*, defaults to `None`): The random seed to use for the noise sampler. If `None`, a random seed will be generated. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -149,6 +156,8 @@ def __init__( prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, noise_sampler_seed: Optional[int] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -187,6 +196,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -226,7 +243,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) @@ -242,9 +277,6 @@ def set_timesteps( sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - timesteps = torch.from_numpy(timesteps) second_order_timesteps = torch.from_numpy(second_order_timesteps) timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 6b08e9bfc207..6b8c2f1a8a28 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -99,7 +99,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) - + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -114,6 +120,8 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -137,15 +145,20 @@ def __init__( sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.is_scale_input_called = False + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -179,7 +192,28 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ + ::-1 + ].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 7237128cbf07..fc52c50ebc7f 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -107,6 +107,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -123,6 +130,8 @@ def __init__( prediction_type: str = "epsilon", interpolation_type: str = "linear", use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -146,9 +155,6 @@ def __init__( sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() @@ -156,6 +162,14 @@ def __init__( self.is_scale_input_called = False self.use_karras_sigmas = use_karras_sigmas + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -191,7 +205,28 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ + ::-1 + ].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 100e2012ea20..28f29067a544 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -78,6 +78,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -93,6 +100,8 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -128,6 +137,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -166,7 +183,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) @@ -180,9 +215,6 @@ def set_timesteps( sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - timesteps = torch.from_numpy(timesteps) timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 2fa0431e1292..d4a35ab82502 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -78,6 +78,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -92,6 +99,8 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -127,6 +136,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -169,7 +186,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) @@ -197,9 +232,6 @@ def set_timesteps( self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]]) self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]]) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - if str(device).startswith("mps"): # mps does not support float64 timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index bb80c4a54bfe..39079fde10d2 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -77,6 +77,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -91,6 +98,8 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -126,6 +135,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): pos = 0 return indices[pos].item() + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, @@ -168,7 +185,25 @@ def set_timesteps( num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) @@ -185,9 +220,6 @@ def set_timesteps( [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]] ) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - if str(device).startswith("mps"): # mps does not support float64 timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 0656475c3093..1256660b843c 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -102,6 +102,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -117,6 +124,8 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, use_karras_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -140,9 +149,6 @@ def __init__( sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - # setable values self.num_inference_steps = None self.use_karras_sigmas = use_karras_sigmas @@ -150,6 +156,14 @@ def __init__( self.derivatives = [] self.is_scale_input_called = False + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + if self.config.timestep_spacing in ["linspace", "trailing"]: + return self.sigmas.max() + + return (self.sigmas.max() ** 2 + 1) ** 0.5 + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -205,7 +219,27 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ + ::-1 + ].copy() + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 01c02a21bbfc..70ee1301129c 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -85,11 +85,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. - """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -106,6 +108,7 @@ def __init__( skip_prk_steps: bool = False, set_alpha_to_one: bool = False, prediction_type: str = "epsilon", + timestep_spacing: str = "leading", steps_offset: int = 0, ): if trained_betas is not None: @@ -159,11 +162,29 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ self.num_inference_steps = num_inference_steps - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() - self._timesteps += self.config.steps_offset + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + self._timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() + self._timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype( + np.int64 + ) + self._timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 7233258a4766..3caa01a58562 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -121,6 +121,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + timestep_spacing (`str`, default `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -145,6 +152,8 @@ def __init__( disable_corrector: List[int] = [], solver_p: SchedulerMixin = None, use_karras_sigmas: Optional[bool] = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -173,7 +182,7 @@ def __init__( if solver_type not in ["bh1", "bh2"]: if solver_type in ["midpoint", "heun", "logrho"]: - self.register_to_config(solver_type="bh1") + self.register_to_config(solver_type="bh2") else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") @@ -199,12 +208,30 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: diff --git a/tests/others/test_config.py b/tests/others/test_config.py index a29190c199ca..d1f8a6e054d4 100644 --- a/tests/others/test_config.py +++ b/tests/others/test_config.py @@ -75,6 +75,22 @@ def __init__( pass +class SampleObject4(ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + e=[1, 5], + f=[5, 4], + ): + pass + + class ConfigTester(unittest.TestCase): def test_load_not_from_mixin(self): with self.assertRaises(ValueError): @@ -137,6 +153,7 @@ def test_save_load(self): assert config.pop("c") == (2, 5) # instantiated as tuple assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json + config.pop("_use_default_values") assert config == new_config def test_load_ddim_from_pndm(self): @@ -233,3 +250,39 @@ def test_load_dpmsolver(self): assert dpm.__class__ == DPMSolverMultistepScheduler # no warning should be thrown assert cap_logger.out == "" + + def test_use_default_values(self): + # let's first save a config that should be in the form + # a=2, + # b=5, + # c=(2, 5), + # d="for diffusion", + # e=[1, 3], + + config = SampleObject() + + config_dict = {k: v for k, v in config.config.items() if not k.startswith("_")} + + # make sure that default config has all keys in `_use_default_values` + assert set(config_dict.keys()) == config.config._use_default_values + + with tempfile.TemporaryDirectory() as tmpdirname: + config.save_config(tmpdirname) + + # now loading it with SampleObject2 should put f into `_use_default_values` + config = SampleObject2.from_config(tmpdirname) + + assert "f" in config._use_default_values + assert config.f == [1, 3] + + # now loading the config, should **NOT** use [1, 3] for `f`, but the default [1, 4] value + # **BECAUSE** it is part of `config._use_default_values` + new_config = SampleObject4.from_config(config.config) + assert new_config.f == [5, 4] + + config.config._use_default_values.pop() + new_config_2 = SampleObject4.from_config(config.config) + assert new_config_2.f == [1, 3] + + # Nevertheless "e" should still be correctly loaded to [1, 3] from SampleObject2 instead of defaulting to [1, 5] + assert new_config_2.e == [1, 3] diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py index 32541c980a15..080bd0091f4f 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py @@ -186,7 +186,7 @@ def test_stable_diffusion_panorama_euler(self): assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.4886, 0.5586, 0.4476, 0.5053, 0.6013, 0.4737, 0.5538, 0.5100, 0.4927]) + expected_slice = np.array([0.4024, 0.6510, 0.4901, 0.5378, 0.5813, 0.5622, 0.4795, 0.4467, 0.4952]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py index aa46ef31885a..0c3b065161db 100644 --- a/tests/schedulers/test_scheduler_euler.py +++ b/tests/schedulers/test_scheduler_euler.py @@ -101,7 +101,7 @@ def test_full_loop_device(self): generator = torch.manual_seed(0) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for t in scheduler.timesteps: @@ -128,7 +128,7 @@ def test_full_loop_device_karras_sigmas(self): generator = torch.manual_seed(0) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for t in scheduler.timesteps: diff --git a/tests/schedulers/test_scheduler_euler_ancestral.py b/tests/schedulers/test_scheduler_euler_ancestral.py index 5fa36be6bc64..9866bd12d6af 100644 --- a/tests/schedulers/test_scheduler_euler_ancestral.py +++ b/tests/schedulers/test_scheduler_euler_ancestral.py @@ -47,7 +47,7 @@ def test_full_loop_no_noise(self): generator = torch.manual_seed(0) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for i, t in enumerate(scheduler.timesteps): @@ -100,7 +100,7 @@ def test_full_loop_device(self): generator = torch.manual_seed(0) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for t in scheduler.timesteps: diff --git a/tests/schedulers/test_scheduler_lms.py b/tests/schedulers/test_scheduler_lms.py index 2682886a788d..1e0a8212354d 100644 --- a/tests/schedulers/test_scheduler_lms.py +++ b/tests/schedulers/test_scheduler_lms.py @@ -97,7 +97,7 @@ def test_full_loop_device(self): scheduler.set_timesteps(self.num_inference_steps, device=torch_device) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu() sample = sample.to(torch_device) for i, t in enumerate(scheduler.timesteps): diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index 62cffc67388c..171ee85be1d3 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -23,7 +23,7 @@ def get_scheduler_config(self, **kwargs): "beta_end": 0.02, "beta_schedule": "linear", "solver_order": 2, - "solver_type": "bh1", + "solver_type": "bh2", } config.update(**kwargs) @@ -144,7 +144,7 @@ def test_switch(self): sample = self.full_loop(scheduler=scheduler) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.2521) < 1e-3 + assert abs(result_mean.item() - 0.2464) < 1e-3 scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config) scheduler = DEISMultistepScheduler.from_config(scheduler.config) @@ -154,7 +154,7 @@ def test_switch(self): sample = self.full_loop(scheduler=scheduler) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.2521) < 1e-3 + assert abs(result_mean.item() - 0.2464) < 1e-3 def test_timesteps(self): for timesteps in [25, 50, 100, 999, 1000]: @@ -206,13 +206,13 @@ def test_full_loop_no_noise(self): sample = self.full_loop() result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.2521) < 1e-3 + assert abs(result_mean.item() - 0.2464) < 1e-3 def test_full_loop_with_v_prediction(self): sample = self.full_loop(prediction_type="v_prediction") result_mean = torch.mean(torch.abs(sample)) - assert abs(result_mean.item() - 0.1096) < 1e-3 + assert abs(result_mean.item() - 0.1014) < 1e-3 def test_fp16_support(self): scheduler_class = self.scheduler_classes[0] diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index a2d065f388bd..d1ae333c0cd2 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -24,10 +24,14 @@ import diffusers from diffusers import ( + DDIMScheduler, + DEISMultistepScheduler, + DiffusionPipeline, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, IPNDMScheduler, LMSDiscreteScheduler, + UniPCMultistepScheduler, VQDiffusionScheduler, logging, ) @@ -202,6 +206,44 @@ def test_save_load_from_different_config_comp_schedulers(self): assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n" assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n" + def test_default_arguments_not_in_config(self): + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype=torch.float16 + ) + assert pipe.scheduler.__class__ == DDIMScheduler + + # Default for DDIMScheduler + assert pipe.scheduler.config.timestep_spacing == "leading" + + # Switch to a different one, verify we use the default for that class + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.timestep_spacing == "linspace" + + # Override with kwargs + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") + assert pipe.scheduler.config.timestep_spacing == "trailing" + + # Verify overridden kwargs stick + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.timestep_spacing == "trailing" + + # And stick + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.timestep_spacing == "trailing" + + def test_default_solver_type_after_switch(self): + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype=torch.float16 + ) + assert pipe.scheduler.__class__ == DDIMScheduler + + pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.solver_type == "logrho" + + # Switch to UniPC, verify the solver is the default + pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + assert pipe.scheduler.config.solver_type == "bh2" + class SchedulerCommonTest(unittest.TestCase): scheduler_classes = () @@ -414,7 +456,11 @@ def test_from_pretrained(self): scheduler.save_pretrained(tmpdirname) new_scheduler = scheduler_class.from_pretrained(tmpdirname) - assert scheduler.config == new_scheduler.config + # `_use_default_values` should not exist for just saved & loaded scheduler + scheduler_config = dict(scheduler.config) + del scheduler_config["_use_default_values"] + + assert scheduler_config == new_scheduler.config def test_step_shape(self): kwargs = dict(self.forward_default_kwargs)