20 | 20 | import torch
21 | 21 | 
22 | 22 | from ..configuration_utils import ConfigMixin, register_to_config
23 |    | -from ..utils import BaseOutput, logging
   | 23 | +from ..utils import BaseOutput, is_scipy_available, logging
24 | 24 | from .scheduling_utils import SchedulerMixin
25 | 25 | 
26 | 26 | 
   | 27 | +if is_scipy_available():
   | 28 | +    import scipy.stats
   | 29 | +
27 | 30 | logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
28 | 31 | 
29 | 32 | 
@@ -72,7 +75,16 @@ def __init__(
72 | 75 |         base_image_seq_len: Optional[int] = 256,
73 | 76 |         max_image_seq_len: Optional[int] = 4096,
74 | 77 |         invert_sigmas: bool = False,
   | 78 | +        use_karras_sigmas: Optional[bool] = False,
   | 79 | +        use_exponential_sigmas: Optional[bool] = False,
   | 80 | +        use_beta_sigmas: Optional[bool] = False,
75 | 81 |     ):
   | 82 | +        if self.config.use_beta_sigmas and not is_scipy_available():
   | 83 | +            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
   | 84 | +        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
   | 85 | +            raise ValueError(
   | 86 | +                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
   | 87 | +            )
76 | 88 |         timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
77 | 89 |         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
78 | 90 | 
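For orientation, here is a minimal usage sketch of the config flags added in this hunk (it assumes a diffusers build that already contains this change; the shift value and step count are purely illustrative). Only one of `use_karras_sigmas`, `use_exponential_sigmas`, `use_beta_sigmas` may be enabled at a time, and the beta option additionally requires scipy:

```python
from diffusers import FlowMatchEulerDiscreteScheduler

# Illustrative values; only one of the three sigma-spacing flags may be True,
# and use_beta_sigmas additionally requires scipy to be installed.
scheduler = FlowMatchEulerDiscreteScheduler(
    num_train_timesteps=1000,
    shift=3.0,
    use_karras_sigmas=True,  # or use_exponential_sigmas / use_beta_sigmas
)
scheduler.set_timesteps(num_inference_steps=30)
print(scheduler.sigmas)  # inspect the resulting Karras-spaced schedule
```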
@@ -185,23 +197,33 @@ def set_timesteps(
185 | 197 |             device (`str` or `torch.device`, *optional*):
186 | 198 |                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
187 | 199 |         """
188 |     | -
189 | 200 |         if self.config.use_dynamic_shifting and mu is None:
190 | 201 |             raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
191 | 202 | 
192 | 203 |         if sigmas is None:
193 |     | -            self.num_inference_steps = num_inference_steps
194 | 204 |             timesteps = np.linspace(
195 | 205 |                 self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
196 | 206 |             )
197 | 207 | 
198 | 208 |             sigmas = timesteps / self.config.num_train_timesteps
    | 209 | +        else:
    | 210 | +            num_inference_steps = len(sigmas)
    | 211 | +        self.num_inference_steps = num_inference_steps
199 | 212 | 
200 | 213 |         if self.config.use_dynamic_shifting:
201 | 214 |             sigmas = self.time_shift(mu, 1.0, sigmas)
202 | 215 |         else:
203 | 216 |             sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
204 | 217 | 
    | 218 | +        if self.config.use_karras_sigmas:
    | 219 | +            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
    | 220 | +
    | 221 | +        elif self.config.use_exponential_sigmas:
    | 222 | +            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
    | 223 | +
    | 224 | +        elif self.config.use_beta_sigmas:
    | 225 | +            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
    | 226 | +
205 | 227 |         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
206 | 228 |         timesteps = sigmas * self.config.num_train_timesteps
207 | 229 | 
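The new `else` branch above also lets callers hand `set_timesteps` a custom sigma schedule, with the step count inferred from its length. A rough sketch of that path, using the lines shown in this hunk (the linear ramp is purely illustrative, not a recommended schedule):

```python
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)

# A 1D numpy array of sigmas in (0, 1]; num_inference_steps is taken from len(sigmas).
custom_sigmas = np.linspace(1.0, 1 / 28, 28)
scheduler.set_timesteps(sigmas=custom_sigmas)

assert scheduler.num_inference_steps == 28
```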
@@ -314,5 +336,85 @@ def step(
314 | 336 | 
315 | 337 |         return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
316 | 338 | 
    | 339 | +    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    | 340 | +    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
    | 341 | +        """Constructs the noise schedule of Karras et al. (2022)."""
    | 342 | +
    | 343 | +        # Hack to make sure that other schedulers which copy this function don't break
    | 344 | +        # TODO: Add this logic to the other schedulers
    | 345 | +        if hasattr(self.config, "sigma_min"):
    | 346 | +            sigma_min = self.config.sigma_min
    | 347 | +        else:
    | 348 | +            sigma_min = None
    | 349 | +
    | 350 | +        if hasattr(self.config, "sigma_max"):
    | 351 | +            sigma_max = self.config.sigma_max
    | 352 | +        else:
    | 353 | +            sigma_max = None
    | 354 | +
    | 355 | +        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
    | 356 | +        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
    | 357 | +
    | 358 | +        rho = 7.0  # 7.0 is the value used in the paper
    | 359 | +        ramp = np.linspace(0, 1, num_inference_steps)
    | 360 | +        min_inv_rho = sigma_min ** (1 / rho)
    | 361 | +        max_inv_rho = sigma_max ** (1 / rho)
    | 362 | +        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    | 363 | +        return sigmas
    | 364 | +
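As a side note, the Karras spacing above interpolates linearly in sigma**(1/rho) space and then raises the result back to the power rho, which packs most of the steps near the low-sigma end. A standalone illustration with toy values (not the scheduler's defaults):

```python
import numpy as np

sigma_min, sigma_max, rho, n = 0.002, 80.0, 7.0, 10  # toy values for illustration
ramp = np.linspace(0, 1, n)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
print(sigmas)  # large steps near sigma_max, progressively finer steps toward sigma_min
```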
    | 365 | +    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    | 366 | +    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
    | 367 | +        """Constructs an exponential noise schedule."""
    | 368 | +
    | 369 | +        # Hack to make sure that other schedulers which copy this function don't break
    | 370 | +        # TODO: Add this logic to the other schedulers
    | 371 | +        if hasattr(self.config, "sigma_min"):
    | 372 | +            sigma_min = self.config.sigma_min
    | 373 | +        else:
    | 374 | +            sigma_min = None
    | 375 | +
    | 376 | +        if hasattr(self.config, "sigma_max"):
    | 377 | +            sigma_max = self.config.sigma_max
    | 378 | +        else:
    | 379 | +            sigma_max = None
    | 380 | +
    | 381 | +        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
    | 382 | +        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
    | 383 | +
    | 384 | +        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
    | 385 | +        return sigmas
    | 386 | +
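The exponential schedule is simply linear interpolation in log-sigma, i.e. the same spacing `np.geomspace` produces. A standalone check with toy values:

```python
import numpy as np

sigma_min, sigma_max, n = 0.05, 1.0, 10  # toy values for illustration
via_log = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), n))
via_geomspace = np.geomspace(sigma_max, sigma_min, n)
assert np.allclose(via_log, via_geomspace)
```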
    | 387 | +    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    | 388 | +    def _convert_to_beta(
    | 389 | +        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    | 390 | +    ) -> torch.Tensor:
    | 391 | +        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
    | 392 | +
    | 393 | +        # Hack to make sure that other schedulers which copy this function don't break
    | 394 | +        # TODO: Add this logic to the other schedulers
    | 395 | +        if hasattr(self.config, "sigma_min"):
    | 396 | +            sigma_min = self.config.sigma_min
    | 397 | +        else:
    | 398 | +            sigma_min = None
    | 399 | +
    | 400 | +        if hasattr(self.config, "sigma_max"):
    | 401 | +            sigma_max = self.config.sigma_max
    | 402 | +        else:
    | 403 | +            sigma_max = None
    | 404 | +
    | 405 | +        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
    | 406 | +        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
    | 407 | +
    | 408 | +        sigmas = np.array(
    | 409 | +            [
    | 410 | +                sigma_min + (ppf * (sigma_max - sigma_min))
    | 411 | +                for ppf in [
    | 412 | +                    scipy.stats.beta.ppf(timestep, alpha, beta)
    | 413 | +                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
    | 414 | +                ]
    | 415 | +            ]
    | 416 | +        )
    | 417 | +        return sigmas
    | 418 | +
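The beta schedule maps evenly spaced probabilities through the percent-point function of a Beta(0.6, 0.6) distribution and rescales the result to [sigma_min, sigma_max]; because that density is U-shaped, steps cluster at both ends of the range. A standalone sketch with toy values:

```python
import numpy as np
import scipy.stats

sigma_min, sigma_max, alpha, beta, n = 0.05, 1.0, 0.6, 0.6, 10  # toy values for illustration
ppfs = scipy.stats.beta.ppf(1 - np.linspace(0, 1, n), alpha, beta)
sigmas = sigma_min + ppfs * (sigma_max - sigma_min)
print(sigmas)  # decreasing from sigma_max to sigma_min, with denser spacing near both endpoints
```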
317 | 419 |     def __len__(self):
318 | 420 |         return self.config.num_train_timesteps