Add timestep_spacing and steps_offset to schedulers #3947

Merged: 30 commits, Jul 5, 2023

Commits (30)
326172d  Add timestep_spacing to DDPM, LMSDiscrete, PNDM. (pcuenca, Jul 3, 2023)
9ac281b  Remove spurious line. (pcuenca, Jul 3, 2023)
77075af  More easy schedulers. (pcuenca, Jul 3, 2023)
5e3c4f7  Add `linspace` to DDIM (pcuenca, Jul 3, 2023)
c1e93d2  Noise sigma for `trailing`. (pcuenca, Jul 3, 2023)
cd4cc28  Add timestep_spacing to DEISMultistepScheduler. (pcuenca, Jul 3, 2023)
02cdd87  Fix: remove line used to debug. (pcuenca, Jul 3, 2023)
30e3a5d  Support timestep_spacing in DPMSolverMultistep, DPMSolverSDE, UniPC (pcuenca, Jul 3, 2023)
8375d0a  Fix: convert to numpy. (pcuenca, Jul 3, 2023)
eb60db4  Use sched. defaults when instantiating from_config (pcuenca, Jul 3, 2023)
81c5a1e  Apply suggestions from code review (pcuenca, Jul 4, 2023)
a38021f  Missing args in DPMSolverMultistep (pcuenca, Jul 4, 2023)
8f9c73c  Test: default args not in config (pcuenca, Jul 4, 2023)
d833276  Style (pcuenca, Jul 4, 2023)
3de3909  Fix scheduler name in test (pcuenca, Jul 4, 2023)
5080d7f  Remove duplicated entries (pcuenca, Jul 4, 2023)
083c8aa  Add test for solver_type (pcuenca, Jul 4, 2023)
bef7669  UniPC: use same default for solver_type (pcuenca, Jul 4, 2023)
21d7eea  do not save use default values (patrickvonplaten, Jul 5, 2023)
cb1109f  fix more (patrickvonplaten, Jul 5, 2023)
46c5c8f  fix all (patrickvonplaten, Jul 5, 2023)
dc8c70b  fix schedulers (patrickvonplaten, Jul 5, 2023)
6470353  fix more (patrickvonplaten, Jul 5, 2023)
72949cb  finish for real (patrickvonplaten, Jul 5, 2023)
45328ef  finish for real (patrickvonplaten, Jul 5, 2023)
8f4c443  flaky tests (patrickvonplaten, Jul 5, 2023)
cedcb1b  Update tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix… (patrickvonplaten, Jul 5, 2023)
a8bdf06  Default steps_offset to 0. (pcuenca, Jul 5, 2023)
f5a1d2c  Add missing docstrings (pcuenca, Jul 5, 2023)
910d609  Apply suggestions from code review (patrickvonplaten, Jul 5, 2023)
16 changes: 15 additions & 1 deletion src/diffusers/configuration_utils.py
@@ -423,6 +423,10 @@ def _get_init_keys(cls):

@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
# Skip keys that were not present in the original config, so default __init__ values were used
used_defaults = config_dict.get("_use_default_values", [])
config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

# 0. Copy origin config dict
original_dict = dict(config_dict.items())

@@ -544,8 +548,9 @@ def to_json_saveable(value):
return value

config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
# Don't save "_ignore_files"
# Don't save "_ignore_files" or "_use_default_values"
config_dict.pop("_ignore_files", None)
config_dict.pop("_use_default_values", None)

return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

@@ -599,6 +604,11 @@ def inner_init(self, *args, **kwargs):
if k not in ignore and k not in new_kwargs
}
)

# Take note of the parameters that were not present in the loaded config
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs)

new_kwargs = {**config_init_kwargs, **new_kwargs}
getattr(self, "register_to_config")(**new_kwargs)
init(self, *args, **init_kwargs)
@@ -643,6 +653,10 @@ def init(self, *args, **kwargs):
name = fields[i].name
new_kwargs[name] = arg

# Take note of the parameters that were not present in the loaded config
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs)

getattr(self, "register_to_config")(**new_kwargs)
original_init(self, *args, **kwargs)

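The hunks above add a `_use_default_values` marker when a config key falls back to its `__init__` default, and strip those keys again in `extract_init_dict`, so a newly introduced default (such as `timestep_spacing`) is not frozen into configs saved by older versions. A minimal standalone sketch of the idea, with illustrative names rather than the actual diffusers internals:

```python
# Minimal sketch of the `_use_default_values` bookkeeping added in this PR.
# Names and structure are illustrative, not the real diffusers implementation.

def register_config(passed_kwargs: dict, defaults: dict) -> dict:
    """Merge explicit kwargs with defaults and record which keys were defaulted."""
    config = {**defaults, **passed_kwargs}
    defaulted = set(defaults) - set(passed_kwargs)
    if defaulted:
        config["_use_default_values"] = defaulted
    return config


def extract_init_dict(config: dict) -> dict:
    """Drop keys that were defaulted, so the current class default wins on reload."""
    used_defaults = config.get("_use_default_values", [])
    return {k: v for k, v in config.items() if k not in used_defaults and k != "_use_default_values"}


# A scheduler saved before `timestep_spacing` existed never set it explicitly:
saved = register_config(
    {"num_train_timesteps": 1000},
    {"num_train_timesteps": 1000, "timestep_spacing": "leading"},
)
# On reload the defaulted key is removed, so whatever the class now defaults to is used.
print(extract_init_dict(saved))  # -> {'num_train_timesteps': 1000}
```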
11 changes: 9 additions & 2 deletions src/diffusers/schedulers/scheduling_ddim.py
@@ -302,8 +302,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

self.num_inference_steps = num_inference_steps

# "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "leading":
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
.round()[::-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
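The same three branches recur in every scheduler this PR touches. A small self-contained numpy sketch (values chosen for illustration, not taken from the diff) shows how the spacings differ for 1000 training timesteps and 10 inference steps; note that only "linspace" and "trailing" reach the final training timestep, one of the issues the linked paper calls out:

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10  # illustrative values

# "linspace": evenly spaced over [0, T-1], inclusive of both endpoints
linspace = (
    np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1].copy().astype(np.int64)
)

# "leading": multiples of T // N (plus steps_offset, omitted here)
step_ratio = num_train_timesteps // num_inference_steps
leading = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)

# "trailing": count down from T in steps of T / N, then shift by -1
step_ratio = num_train_timesteps / num_inference_steps
trailing = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1

print(linspace)  # [999 888 777 666 555 444 333 222 111   0]
print(leading)   # [900 800 700 600 500 400 300 200 100   0]
print(trailing)  # [999 899 799 699 599 499 399 299 199  99]
```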
11 changes: 9 additions & 2 deletions src/diffusers/schedulers/scheduling_ddim_parallel.py
@@ -321,8 +321,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

self.num_inference_steps = num_inference_steps

# "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "leading":
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
.round()[::-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
37 changes: 34 additions & 3 deletions src/diffusers/schedulers/scheduling_ddpm.py
@@ -114,6 +114,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, default `"leading"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -134,6 +141,8 @@ def __init__(
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -228,11 +237,33 @@ def set_timesteps(
)

self.num_inference_steps = num_inference_steps

step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
self.custom_timesteps = False

# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
.round()[::-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)

self.timesteps = torch.from_numpy(timesteps).to(device)

def _get_variance(self, t, predicted_variance=None, variance_type=None):
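With `timestep_spacing` exposed through the scheduler config, a user can opt into a different spacing when reloading an existing checkpoint. A hypothetical usage sketch, assuming the usual diffusers `from_config(config, **overrides)` pattern; the model id is only an example:

```python
from diffusers import DiffusionPipeline, DDIMScheduler

# Illustrative checkpoint; any Stable Diffusion pipeline would do.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Override the spacing recorded in (or defaulted by) the scheduler config.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",  # sample the final training timestep, per the linked paper
)
```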
37 changes: 34 additions & 3 deletions src/diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -116,6 +116,13 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, default `"leading"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -138,6 +145,8 @@ def __init__(
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -234,11 +243,33 @@ def set_timesteps(
)

self.num_inference_steps = num_inference_steps

step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
self.custom_timesteps = False

# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
.round()[::-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)

self.timesteps = torch.from_numpy(timesteps).to(device)

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._get_variance
39 changes: 33 additions & 6 deletions src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -107,6 +107,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -129,6 +136,8 @@ def __init__(
solver_type: str = "logrho",
lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -185,12 +194,30 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas:
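Unlike the DDIM/DDPM variants above, the multistep solvers build `num_inference_steps + 1` points and drop the last one, so the grid excludes timestep 0. A quick numpy check of the "leading" branch (illustrative numbers again):

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10  # illustrative values

# Multistep flavour of "leading": N + 1 points, final entry dropped.
step_ratio = num_train_timesteps // (num_inference_steps + 1)  # 90
timesteps = (
    (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
)
print(timesteps)  # [900 810 720 630 540 450 360 270 180  90]
```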
38 changes: 32 additions & 6 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -134,6 +134,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
diffusion ODEs.
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
Expand All @@ -158,6 +165,8 @@ def __init__(
use_karras_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -217,12 +226,29 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = last_timestep // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas: