
Separate Sigma Schedule #10146


Closed · wants to merge 21 commits into from
28 changes: 28 additions & 0 deletions src/diffusers/configuration_utils.py
@@ -245,6 +245,34 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)

# Handle old scheduler configs
Contributor Author:

This will need to be robust and probably kept for a while, unless we can find a way to mass-update configs on the Hub. It's already working with some scheduler configs: FlowMatch vs. Beta is detected via `shift` and `beta_schedule`. I've already found an edge case in SANA's config: because we integrated those scheduler changes into DPM, it has `beta_schedule` but no `shift` (it was called `flow_shift` instead).
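To make the mapping concrete, here is a rough sketch of what this block does to a legacy flat config. The values below are made up; only the keys the code actually checks (`use_karras_sigmas` / `use_exponential_sigmas` / `use_beta_sigmas`, `beta_schedule`, `shift`) matter:

```python
# Hypothetical legacy flat config, e.g. an Euler-style scheduler saved before this PR.
old_config = {
    "_class_name": "EulerDiscreteScheduler",
    "_diffusers_version": "0.31.0",
    "prediction_type": "epsilon",
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "beta_schedule": "scaled_linear",
    "use_karras_sigmas": True,
}

# Roughly what the branch above turns it into: top-level metadata is kept,
# the remaining keys move under "schedule_config", and the sigma flags become
# a separate "sigma_schedule_config".
migrated_config = {
    "_class_name": "EulerDiscreteScheduler",
    "_diffusers_version": "0.31.0",
    "prediction_type": "epsilon",
    "schedule_config": {
        "class_name": "BetaSchedule",  # chosen because "beta_schedule" is present
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
    },
    "sigma_schedule_config": {"class_name": "KarrasSigmas"},
}
```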

if "Scheduler" in cls.__name__ and "schedule_config" not in config:
prediction_type = config.pop("prediction_type", None)
_class_name = config.pop("_class_name", None)
_diffusers_version = config.pop("_diffusers_version", None)
use_karras_sigmas = config.pop("use_karras_sigmas", None)
use_exponential_sigmas = config.pop("use_exponential_sigmas", None)
use_beta_sigmas = config.pop("use_beta_sigmas", None)
if use_karras_sigmas:
sigma_schedule_config = {"class_name": "KarrasSigmas"}
elif use_exponential_sigmas:
sigma_schedule_config = {"class_name": "ExponentialSigmas"}
elif use_beta_sigmas:
sigma_schedule_config = {"class_name": "BetaSigmas"}
else:
sigma_schedule_config = {}
if "beta_schedule" in config:
config.update({"class_name": "BetaSchedule"})
elif "shift" in config:
config.update({"class_name": "FlowMatchSchedule"})
config = {
"_class_name": _class_name,
"_diffusers_version": _diffusers_version,
"prediction_type": prediction_type,
"schedule_config": config,
"sigma_schedule_config": sigma_schedule_config,
}

init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)

# Allow dtype to be specified on initialization
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/flux/pipeline_flux.py
@@ -15,7 +15,6 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

@@ -699,7 +698,8 @@ def __call__(
)

# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
if self.scheduler.schedule.__class__.__name__ != "FlowMatchFlux":
    self.scheduler._schedule.set_base_schedule("FlowMatchFlux")
image_seq_len = latents.shape[1]
mu = calculate_shift(
    image_seq_len,
8 changes: 3 additions & 5 deletions src/diffusers/pipelines/mochi/pipeline_mochi.py
@@ -15,7 +15,6 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast

@@ -495,6 +494,7 @@ def __call__(
num_frames: int = 19,
num_inference_steps: int = 64,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 4.5,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -652,10 +652,8 @@ def __call__(
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

# 5. Prepare timestep
# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
threshold_noise = 0.025
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
sigmas = np.array(sigmas)
if self.scheduler.schedule.__class__.__name__ != "FlowMatchLinearQuadratic":
    self.scheduler._schedule.set_base_schedule("FlowMatchLinearQuadratic")

timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler,
252 changes: 252 additions & 0 deletions src/diffusers/schedulers/schedules/beta_schedule.py
@@ -0,0 +1,252 @@
import math
from typing import List, Optional, Union

import numpy as np
import torch

from ...configuration_utils import ConfigMixin, register_to_config
from ..sigmas.beta_sigmas import BetaSigmas
from ..sigmas.exponential_sigmas import ExponentialSigmas
from ..sigmas.karras_sigmas import KarrasSigmas


def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas


class BetaSchedule:

    scale_model_input = True

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        rescale_betas_zero_snr: bool = False,
        interpolation_type: str = "linear",
        timestep_spacing: str = "linspace",
        timestep_type: str = "discrete",  # can be "discrete" or "continuous"
        steps_offset: int = 0,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
        **kwargs,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_schedule = beta_schedule
        self.trained_betas = trained_betas
        self.rescale_betas_zero_snr = rescale_betas_zero_snr
        self.interpolation_type = interpolation_type
        self.timestep_spacing = timestep_spacing
        self.timestep_type = timestep_type
        self.steps_offset = steps_offset
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.final_sigmas_type = final_sigmas_type

    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    def __call__(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
        sigmas: Optional[List[float]] = None,
        sigma_schedule: Optional[Union[KarrasSigmas, ExponentialSigmas, BetaSigmas]] = None,
        **kwargs,
    ):
        if sigmas is not None:
            log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
            sigmas = np.array(sigmas).astype(np.float32)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])

        else:
            if timesteps is not None:
                timesteps = np.array(timesteps).astype(np.float32)
            else:
                # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
                if self.timestep_spacing == "linspace":
                    timesteps = np.linspace(
                        0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
                    )[::-1].copy()
                elif self.timestep_spacing == "leading":
                    step_ratio = self.num_train_timesteps // num_inference_steps
                    # creates integer timesteps by multiplying by ratio
                    # casting to int to avoid issues when num_inference_step is power of 3
                    timesteps = (
                        (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
                    )
                    timesteps += self.steps_offset
                elif self.timestep_spacing == "trailing":
                    step_ratio = self.num_train_timesteps / num_inference_steps
                    # creates integer timesteps by multiplying by ratio
                    # casting to int to avoid issues when num_inference_step is power of 3
                    timesteps = (
                        (np.arange(self.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
                    )
                    timesteps -= 1
                else:
                    raise ValueError(
                        f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                    )

            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            log_sigmas = np.log(sigmas)
            if self.interpolation_type == "linear":
                sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            elif self.interpolation_type == "log_linear":
                sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
            else:
                raise ValueError(
                    f"{self.interpolation_type} is not implemented. Please specify interpolation_type to either"
                    " 'linear' or 'log_linear'"
                )

            if sigma_schedule is not None:
                sigmas = sigma_schedule(sigmas)
                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

            if self.final_sigmas_type == "sigma_min":
                sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            elif self.final_sigmas_type == "zero":
                sigma_last = 0
            else:
                raise ValueError(
                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.final_sigmas_type}"
                )

            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        # TODO: Support the full EDM scalings for all prediction types and timestep types
        if self.timestep_type == "continuous" and self.prediction_type == "v_prediction":
            timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
        else:
            timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)

        return sigmas, timesteps
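
As an end-to-end sanity check, here is a minimal usage sketch of the schedule class added in this file. It assumes this branch is installed and uses the module paths introduced by the PR; the constructor values are illustrative, and constructing `KarrasSigmas()` with no arguments is an assumption, since that file is not shown in this diff:

```python
from diffusers.schedulers.schedules.beta_schedule import BetaSchedule
from diffusers.schedulers.sigmas.karras_sigmas import KarrasSigmas  # added elsewhere in this PR

# Discrete beta schedule, roughly matching an SD-style Euler scheduler config.
schedule = BetaSchedule(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    timestep_spacing="leading",
)

# __call__ returns torch tensors; the final sigma (0 with the default
# final_sigmas_type) is appended, so there is one more sigma than timesteps.
sigmas, timesteps = schedule(num_inference_steps=30, device="cpu")
assert sigmas.shape[0] == timesteps.shape[0] + 1

# Optionally reshape the sigmas with one of the separate sigma-schedule classes;
# a no-argument KarrasSigmas() is assumed here.
karras_sigmas, karras_timesteps = schedule(
    num_inference_steps=30, device="cpu", sigma_schedule=KarrasSigmas()
)
```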