@@ -117,6 +117,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
        lower_order_final (`bool`, default `True`):
            whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable
            this to use up all the function evaluations.
+       use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+           Whether to use Karras sigmas (the Karras et al. (2022) scheme) for the step sizes in the noise schedule
+           during sampling. If `True`, the sigmas are determined according to the sequence of noise levels {σ_i}
+           defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
        lambda_min_clipped (`float`, default `-inf`):
            the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for
            the cosine (`squaredcos_cap_v2`) noise schedule.
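For reference, a minimal usage sketch of the new flag, assuming the usual diffusers pipeline API; the model id below is only an example:

```python
from diffusers import DiffusionPipeline, DPMSolverSinglestepScheduler

# load any pipeline (example model id) and swap in the singlestep scheduler,
# enabling Karras sigmas via a config override
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=True
)
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
```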
@@ -150,6 +154,7 @@ def __init__(
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
+       use_karras_sigmas: Optional[bool] = False,
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
    ):
@@ -197,6 +202,7 @@ def __init__(
        self.model_outputs = [None] * solver_order
        self.sample = None
        self.order_list = self.get_order_list(num_train_timesteps)
+       self.use_karras_sigmas = use_karras_sigmas

    def get_order_list(self, num_inference_steps: int) -> List[int]:
        """
@@ -252,6 +258,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
                .copy()
                .astype(np.int64)
            )
+
+        if self.use_karras_sigmas:
+            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+            log_sigmas = np.log(sigmas)
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps = np.flip(timesteps).copy().astype(np.int64)
+
        self.timesteps = torch.from_numpy(timesteps).to(device)
        self.model_outputs = [None] * self.config.solver_order
        self.sample = None
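To see what this branch computes end to end, here is a minimal standalone NumPy sketch (not scheduler code): it assumes a linear beta schedule, builds the per-timestep sigmas, rho-spaces them as `_convert_to_karras` does, and maps them back to integer timesteps the way `_sigma_to_t` does (here via the equivalent `np.interp` in log-sigma space):

```python
import numpy as np

# assumed example schedule, not taken from the diff
num_train_timesteps, num_inference_steps = 1000, 25
betas = np.linspace(1e-4, 0.02, num_train_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)

# sigma(t) for every training timestep, ascending in t
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
log_sigmas = np.log(sigmas)

# rho-spaced noise levels between sigma_min and sigma_max (rho = 7)
rho, ramp = 7.0, np.linspace(0, 1, num_inference_steps)
lo, hi = sigmas[0] ** (1 / rho), sigmas[-1] ** (1 / rho)
karras_sigmas = (lo + ramp * (hi - lo)) ** rho

# map each Karras sigma back to a training timestep by linear interpolation
# in log-sigma space, then flip so sampling runs from high noise to low noise
t = np.interp(np.log(karras_sigmas), log_sigmas, np.arange(num_train_timesteps))
timesteps = np.flip(t.round()).astype(np.int64)
print(timesteps)  # descending; spacing is densest at low noise levels
```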
@@ -299,6 +313,44 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        return sample

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
    def convert_model_output(
        self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
    ) -> torch.FloatTensor:
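For reference, `_convert_to_karras` implements Equation (5) of Karras et al. (2022): with N = `num_inference_steps` and ρ = 7,

```latex
\sigma_i = \Bigl( \sigma_{\max}^{1/\rho} + \tfrac{i}{N-1}\bigl(\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho}\bigr) \Bigr)^{\rho},
\qquad i = 0, \dots, N-1,
```

i.e. the noise levels interpolate linearly between σ_max and σ_min in σ^{1/ρ} space, which concentrates steps at low noise. Note that `set_timesteps` passes `in_sigmas` in ascending order, so `sigma_min`/`sigma_max` here pick up the largest/smallest values respectively and the returned array ascends; the subsequent `np.flip` then yields the usual descending timestep schedule.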