@@ -71,6 +71,43 @@ def alpha_bar_fn(t):
71
71
return torch .tensor (betas , dtype = torch .float32 )
72
72
73
73
74
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that the final timestep has zero terminal SNR.

    Implements Algorithm 1 of https://arxiv.org/pdf/2305.08891.pdf: the
    cumulative signal scale sqrt(alpha_bar) is shifted so its last entry is
    exactly zero, then rescaled so its first entry keeps its original value,
    and the per-step betas are recovered from the adjusted cumulative products.

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Work in sqrt(alpha_bar) space, where the rescale is a linear map.
    sqrt_alphas_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember both endpoints before modifying the schedule.
    first_value = sqrt_alphas_bar[0].clone()
    last_value = sqrt_alphas_bar[-1].clone()

    # Shift so the terminal entry becomes zero, then scale so the initial
    # entry is restored to its original value.
    sqrt_alphas_bar = sqrt_alphas_bar - last_value
    sqrt_alphas_bar = sqrt_alphas_bar * (first_value / (first_value - last_value))

    # Convert back to betas: square to undo the sqrt, divide consecutive
    # cumulative products to undo the cumprod, and complement.
    alphas_bar = sqrt_alphas_bar**2
    stepwise_alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], stepwise_alphas])
    return 1 - alphas
110
+
74
111
class DPMSolverMultistepScheduler (SchedulerMixin , ConfigMixin ):
75
112
"""
76
113
`DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
@@ -144,6 +181,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
144
181
An offset added to the inference steps. You can use a combination of `offset=1` and
145
182
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
146
183
Diffusion.
184
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
185
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
186
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
187
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
147
188
"""
148
189
149
190
_compatibles = [e .name for e in KarrasDiffusionSchedulers ]
@@ -173,6 +214,7 @@ def __init__(
173
214
variance_type : Optional [str ] = None ,
174
215
timestep_spacing : str = "linspace" ,
175
216
steps_offset : int = 0 ,
217
+ rescale_betas_zero_snr : bool = False ,
176
218
):
177
219
if algorithm_type in ["dpmsolver" , "sde-dpmsolver" ]:
178
220
deprecation_message = f"algorithm_type { algorithm_type } is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
@@ -191,8 +233,17 @@ def __init__(
191
233
else :
192
234
raise NotImplementedError (f"{ beta_schedule } does is not implemented for { self .__class__ } " )
193
235
236
+ if rescale_betas_zero_snr :
237
+ self .betas = rescale_zero_terminal_snr (self .betas )
238
+
194
239
self .alphas = 1.0 - self .betas
195
240
self .alphas_cumprod = torch .cumprod (self .alphas , dim = 0 )
241
+
242
+ if rescale_betas_zero_snr :
243
+ # Close to 0 without being 0 so first sigma is not inf
244
+ # FP16 smallest positive subnormal works well here
245
+ self .alphas_cumprod [- 1 ] = 2 ** - 24
246
+
196
247
# Currently we only support VP-type noise schedule
197
248
self .alpha_t = torch .sqrt (self .alphas_cumprod )
198
249
self .sigma_t = torch .sqrt (1 - self .alphas_cumprod )
@@ -895,9 +946,12 @@ def step(
895
946
self .model_outputs [i ] = self .model_outputs [i + 1 ]
896
947
self .model_outputs [- 1 ] = model_output
897
948
949
+ # Upcast to avoid precision issues when computing prev_sample
950
+ sample = sample .to (torch .float32 )
951
+
898
952
if self .config .algorithm_type in ["sde-dpmsolver" , "sde-dpmsolver++" ]:
899
953
noise = randn_tensor (
900
- model_output .shape , generator = generator , device = model_output .device , dtype = model_output . dtype
954
+ model_output .shape , generator = generator , device = model_output .device , dtype = torch . float32
901
955
)
902
956
else :
903
957
noise = None
@@ -912,6 +966,9 @@ def step(
912
966
if self .lower_order_nums < self .config .solver_order :
913
967
self .lower_order_nums += 1
914
968
969
+ # Cast sample back to expected dtype
970
+ prev_sample = prev_sample .to (model_output .dtype )
971
+
915
972
# upon completion increase step index by one
916
973
self ._step_index += 1
917
974
0 commit comments