Skip to content

Commit e222246

Browse files
authored
Fix sigma_last with use_flow_sigmas (#10267)
1 parent 83709d5 commit e222246

File tree

4 files changed

+14
-0
lines changed

4 files changed

+14
-0
lines changed

src/diffusers/schedulers/scheduling_deis_multistep.py

+1
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
289289
sigmas = 1.0 - alphas
290290
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
291291
timesteps = (sigmas * self.config.num_train_timesteps).copy()
292+
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
292293
else:
293294
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
294295
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

+3
Original file line numberDiff line numberDiff line change
@@ -291,14 +291,17 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
291291
elif self.config.use_exponential_sigmas:
292292
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
293293
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
294+
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
294295
elif self.config.use_beta_sigmas:
295296
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
296297
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
298+
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
297299
elif self.config.use_flow_sigmas:
298300
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
299301
sigmas = 1.0 - alphas
300302
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
301303
timesteps = (sigmas * self.config.num_train_timesteps).copy()
304+
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
302305
else:
303306
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
304307
sigma_max = (

src/diffusers/schedulers/scheduling_sasolver.py

+1
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
318318
sigmas = 1.0 - alphas
319319
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
320320
timesteps = (sigmas * self.config.num_train_timesteps).copy()
321+
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
321322
else:
322323
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
323324
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5

src/diffusers/schedulers/scheduling_unipc_multistep.py

+9
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
381381
sigmas = 1.0 - alphas
382382
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
383383
timesteps = (sigmas * self.config.num_train_timesteps).copy()
384+
if self.config.final_sigmas_type == "sigma_min":
385+
sigma_last = sigmas[-1]
386+
elif self.config.final_sigmas_type == "zero":
387+
sigma_last = 0
388+
else:
389+
raise ValueError(
390+
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
391+
)
392+
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
384393
else:
385394
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
386395
if self.config.final_sigmas_type == "sigma_min":

0 commit comments

Comments
 (0)