Skip to content

Commit 2e31a75

Browse files
authored
DPMSolverMultistep add rescale_betas_zero_snr (huggingface#7097)
* DPMMultistep rescale_betas_zero_snr * DPM upcast samples in step() * DPM rescale_betas_zero_snr UT * DPMSolverMulti move sample upcast after model convert Avoids having to re-use the dtype. * Add a newline for Ruff
1 parent e51862b commit 2e31a75

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

Diff for: src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

+58-1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,43 @@ def alpha_bar_fn(t):
7171
return torch.tensor(betas, dtype=torch.float32)
7272

7373

74+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
75+
def rescale_zero_terminal_snr(betas):
76+
"""
77+
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
78+
79+
80+
Args:
81+
betas (`torch.FloatTensor`):
82+
the betas that the scheduler is being initialized with.
83+
84+
Returns:
85+
`torch.FloatTensor`: rescaled betas with zero terminal SNR
86+
"""
87+
# Convert betas to alphas_bar_sqrt
88+
alphas = 1.0 - betas
89+
alphas_cumprod = torch.cumprod(alphas, dim=0)
90+
alphas_bar_sqrt = alphas_cumprod.sqrt()
91+
92+
# Store old values.
93+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
94+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
95+
96+
# Shift so the last timestep is zero.
97+
alphas_bar_sqrt -= alphas_bar_sqrt_T
98+
99+
# Scale so the first timestep is back to the old value.
100+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
101+
102+
# Convert alphas_bar_sqrt to betas
103+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
104+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
105+
alphas = torch.cat([alphas_bar[0:1], alphas])
106+
betas = 1 - alphas
107+
108+
return betas
109+
110+
74111
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
75112
"""
76113
`DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
@@ -144,6 +181,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
144181
An offset added to the inference steps. You can use a combination of `offset=1` and
145182
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
146183
Diffusion.
184+
rescale_betas_zero_snr (`bool`, defaults to `False`):
185+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
186+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
187+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
147188
"""
148189

149190
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -173,6 +214,7 @@ def __init__(
173214
variance_type: Optional[str] = None,
174215
timestep_spacing: str = "linspace",
175216
steps_offset: int = 0,
217+
rescale_betas_zero_snr: bool = False,
176218
):
177219
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
178220
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
@@ -191,8 +233,17 @@ def __init__(
191233
else:
192234
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
193235

236+
if rescale_betas_zero_snr:
237+
self.betas = rescale_zero_terminal_snr(self.betas)
238+
194239
self.alphas = 1.0 - self.betas
195240
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
241+
242+
if rescale_betas_zero_snr:
243+
# Close to 0 without being 0 so first sigma is not inf
244+
# FP16 smallest positive subnormal works well here
245+
self.alphas_cumprod[-1] = 2**-24
246+
196247
# Currently we only support VP-type noise schedule
197248
self.alpha_t = torch.sqrt(self.alphas_cumprod)
198249
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
@@ -895,9 +946,12 @@ def step(
895946
self.model_outputs[i] = self.model_outputs[i + 1]
896947
self.model_outputs[-1] = model_output
897948

949+
# Upcast to avoid precision issues when computing prev_sample
950+
sample = sample.to(torch.float32)
951+
898952
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
899953
noise = randn_tensor(
900-
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
954+
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
901955
)
902956
else:
903957
noise = None
@@ -912,6 +966,9 @@ def step(
912966
if self.lower_order_nums < self.config.solver_order:
913967
self.lower_order_nums += 1
914968

969+
# Cast sample back to expected dtype
970+
prev_sample = prev_sample.to(model_output.dtype)
971+
915972
# upon completion increase step index by one
916973
self._step_index += 1
917974

Diff for: tests/schedulers/test_scheduler_dpm_multi.py

+4
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,10 @@ def test_inference_steps(self):
213213
for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
214214
self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)
215215

216+
def test_rescale_betas_zero_snr(self):
217+
for rescale_betas_zero_snr in [True, False]:
218+
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
219+
216220
def test_full_loop_no_noise(self):
217221
sample = self.full_loop()
218222
result_mean = torch.mean(torch.abs(sample))

0 commit comments

Comments
 (0)