
Commit 4da42c2

hlky authored and sayakpaul committed
Add beta, exponential and karras sigmas to FlowMatchEulerDiscreteScheduler (#10001)
1 parent c5f6fb9 commit 4da42c2

1 file changed (+105, -3 lines)

src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

@@ -20,10 +20,13 @@
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from .scheduling_utils import SchedulerMixin
 
 
+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -72,7 +75,16 @@ def __init__(
         base_image_seq_len: Optional[int] = 256,
         max_image_seq_len: Optional[int] = 4096,
         invert_sigmas: bool = False,
+        use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
 
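With the three flags registered on the config, the spacing can be switched at construction time. A minimal usage sketch of how this might look after this change (the flag names come from the diff above; everything else is illustrative):

```python
from diffusers import FlowMatchEulerDiscreteScheduler

# Only one of the three flags may be enabled at once; otherwise __init__ raises
# the ValueError added above. use_beta_sigmas additionally requires scipy,
# per the ImportError guard.
scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
```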
@@ -185,23 +197,33 @@ def set_timesteps(
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-
         if self.config.use_dynamic_shifting and mu is None:
             raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
 
         if sigmas is None:
-            self.num_inference_steps = num_inference_steps
             timesteps = np.linspace(
                 self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
             )
 
             sigmas = timesteps / self.config.num_train_timesteps
+        else:
+            num_inference_steps = len(sigmas)
+        self.num_inference_steps = num_inference_steps
 
         if self.config.use_dynamic_shifting:
             sigmas = self.time_shift(mu, 1.0, sigmas)
         else:
             sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
 
+        if self.config.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
         timesteps = sigmas * self.config.num_train_timesteps
 
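Once `set_timesteps` has run, the selected conversion is already baked into the scheduler's `sigmas` and `timesteps` tensors, so the effect of a flag can be inspected directly. A rough sketch, assuming default config values otherwise:

```python
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(use_exponential_sigmas=True)
scheduler.set_timesteps(num_inference_steps=30)

# Sigmas now follow the exponential (log-linear) spacing rather than the plain
# shifted linspace; timesteps are derived from them as sigma * num_train_timesteps.
print(scheduler.sigmas)
print(scheduler.timesteps)
```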
@@ -314,5 +336,85 @@ def step(
 
         return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def __len__(self):
         return self.config.num_train_timesteps
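All three converters share the same pattern: read the sigma range off the incoming schedule, then re-space `num_inference_steps` values between `sigma_max` and `sigma_min`. A standalone numpy/scipy sketch of the three spacings, detached from the scheduler class and purely illustrative:

```python
import math

import numpy as np
import scipy.stats

sigma_min, sigma_max, n = 0.03, 1.0, 10

# Karras et al. (2022): interpolate in sigma**(1/rho) space with rho = 7.
rho = 7.0
ramp = np.linspace(0, 1, n)
karras = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

# Exponential: uniform spacing in log(sigma).
exponential = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), n))

# Beta sampling (Lee et al., 2024): spacing driven by Beta(0.6, 0.6) quantiles.
ppf = scipy.stats.beta.ppf(1 - np.linspace(0, 1, n), 0.6, 0.6)
beta = sigma_min + ppf * (sigma_max - sigma_min)

for name, s in [("karras", karras), ("exponential", exponential), ("beta", beta)]:
    print(name, np.round(s, 4))
```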
