@@ -51,13 +51,52 @@ def prepare_extra_step_kwargs(self, generator, eta):
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

+    def get_sigma_min_max_from_scheduler(self):
+        # Get sigma_min, sigma_max in original sigma space, not Karras sigma space
+        # (e.g. not exponentiated by 1 / rho)
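+        # (Illustration, assuming the usual Karras parameter rho, typically 7.0: "Karras sigma space"
+        # means values of the form sigma ** (1 / rho), which the Karras schedule interpolates over
+        # before raising back to the power rho; the values returned here should be plain sigmas.)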
+        if hasattr(self.scheduler, "sigma_min"):
+            sigma_min = self.scheduler.sigma_min
+            sigma_max = self.scheduler.sigma_max
+        elif hasattr(self.scheduler, "sigmas"):
+            # Karras-style scheduler (e.g. EulerDiscreteScheduler, HeunDiscreteScheduler)
+            # Get sigma_min, sigma_max before they're converted into Karras sigma space by set_timesteps
+            # TODO: Karras schedulers are inconsistent about how they initialize sigmas in __init__
+            # For example, EulerDiscreteScheduler gets sigmas in original sigma space, but HeunDiscreteScheduler
+            # initializes them through set_timesteps, which potentially leaves the sigmas in Karras sigma space.
+            # TODO: For example, in EulerDiscreteScheduler, a value of 0.0 is appended to the sigmas when they
+            # are initialized in __init__. But wouldn't we usually want sigma_min to be a small positive number,
+            # following the consistency models paper?
+            # See e.g. https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L13
+            sigma_min = self.scheduler.sigmas[-1].item()
+            sigma_max = self.scheduler.sigmas[0].item()
+        else:
+            raise ValueError(
+                f"Scheduler {self.scheduler.__class__} does not have sigma_min or sigma_max."
+            )
+        return sigma_min, sigma_max
+
+    def get_sigmas_from_scheduler(self):
+        if hasattr(self.scheduler, "sigmas"):
+            # e.g. HeunDiscreteScheduler
+            sigmas = self.scheduler.sigmas
+        elif hasattr(self.scheduler, "schedule"):
+            # e.g. KarrasVeScheduler
+            sigmas = self.scheduler.schedule
+        else:
+            raise ValueError(
+                f"Scheduler {self.scheduler.__class__} does not have sigmas."
+            )
+        return sigmas
+
    def get_scalings(self, sigma, sigma_data: float = 0.5):
        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
        c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5
        return c_skip, c_out, c_in
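+    # Note (assumption, not stated in this hunk): these are the EDM-style preconditioning
+    # coefficients, and denoise() presumably combines them as
+    #   denoised = c_out * model_output + c_skip * x_t
+    # with the UNet evaluated on c_in * x_t. For example, at sigma == sigma_data == 0.5 this
+    # gives c_skip = 0.5, c_out ~= 0.354, c_in ~= 1.414.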
    def get_scalings_for_boundary_condition(sigma, sigma_min, sigma_data: float = 0.5):
+        # sigma_min should be in original sigma space, not in Karras sigma space
+        # (e.g. not exponentiated by 1 / rho)
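+        # (Sketch of the intent, following the consistency models paper: as sigma -> sigma_min,
+        # c_skip -> 1 and c_out -> 0, so the parameterized model reduces to the identity at the
+        # boundary, i.e. f(x, sigma_min) = x.)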
        c_skip = sigma_data**2 / (
            (sigma - sigma_min) ** 2 + sigma_data**2
        )
@@ -73,6 +112,8 @@ def denoise(self, x_t, sigma, sigma_min, sigma_data: float = 0.5, clip_denoised=
        """
        Run the consistency model forward pass, returning both the raw model output and the
        denoised sample.
        """
+        # sigma_min should be in original sigma space, not in Karras sigma space
+        # (e.g. not exponentiated by 1 / rho)
        c_skip, c_out, c_in = [
            append_dims(x, x_t.ndim)
            for x in self.get_scalings_for_boundary_condition(sigma, sigma_min, sigma_data=sigma_data)
@@ -88,26 +129,6 @@ def to_d(x, sigma, denoised):
        """Converts a denoiser output to a Karras ODE derivative."""
        return (x - denoised) / append_dims(sigma, x.ndim)

-    def add_noise_to_input(
-        self,
-        sample: torch.FloatTensor,
-        sigma_hat: float,
-        sigma_min: float,
-        sigma_max: float,
-        s_noise: float = 1.0,
-        generator: Optional[torch.Generator] = None,
-    ):
-        # Clamp sigma_hat
-        sigma_hat = sigma_hat.clamp(min=sigma_min, max=sigma_max)
-
-        # sample z ~ N(0, s_noise^2 * I)
-        z = s_noise * randn_tensor(sample.shape, generator=generator, device=sample.device)
-
-        # tau = sigma_hat; eps = sigma_min
-        sample_hat = sample + ((sigma_hat**2 - sigma_min**2) ** 0.5 * z)
-
-        return sample_hat
-
    @torch.no_grad()
    def __call__(
        self,
@@ -144,69 +165,76 @@ def __call__(
        img_size = self.unet.config.sample_size
        shape = (batch_size, 3, img_size, img_size)
        device = self.device
-        scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
-        scheduler_has_sigma_min = hasattr(self.scheduler, "sigma_min")
-        assert scheduler_has_sigma_min or scheduler_is_in_sigma_space, "Scheduler needs to have sigmas"

        # 1. Sample image latents x_0 ~ N(0, sigma_0^2 * I)
        sample = randn_tensor(shape, generator=generator, device=device) * self.scheduler.init_noise_sigma

        # 2. Set timesteps and get sigmas
+        # Get sigma_min, sigma_max in original sigma space (not Karras sigma space)
+        sigma_min, sigma_max = self.get_sigma_min_max_from_scheduler()
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps
+
+        # Now get Karras sigma schedule (which I think the original implementation always uses)
+        # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L376
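+        # (Rough sketch of that schedule, assuming rho = 7.0 and N = num_inference_steps:
+        #   ramp = i / (N - 1)
+        #   sigmas[i] = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
+        # so sigmas runs from sigma_max down to sigma_min.)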
+        sigmas = self.get_sigmas_from_scheduler()

        # 3. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 4. Denoising loop
-        if scheduler_has_sigma_min:
-            # 4.1 Scheduler which can add noise to input (e.g. KarrasVeScheduler)
-            sigma_min = self.scheduler.sigma_min
-            sigma_max = self.scheduler.sigma_max
-            s_noise = self.scheduler.s_noise
-            sigmas = self.scheduler.schedule
-
+        # TODO: hack, is there a better way to identify schedulers that implement the stochastic iterative sampling
+        # similar to stochastic_iterative_sampler in the original code?
+        if hasattr(self.scheduler, "add_noise_to_input"):
+            # 4.1 Consistency Model Stochastic Iterative Scheduler (multi-step sampling)
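+            # (Roughly, multistep consistency sampling alternates, for decreasing sigma:
+            #   x <- x_0 + sqrt(sigma ** 2 - sigma_min ** 2) * z,  with z ~ N(0, I)
+            #   x_0 <- consistency_model(x, sigma)
+            # which the loop below delegates to the scheduler's add_noise_to_input and step.)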
            # First evaluate the consistency model. This will be the output sample if num_inference_steps == 1
-            sigma = sigmas[timesteps[0]]
+            # TODO: not all schedulers have an index_for_timestep method (e.g. KarrasVeScheduler)
+            step_idx = self.scheduler.index_for_timestep(timesteps[0])
+            sigma = sigmas[step_idx]
            _, sample = self.denoise(sample, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised)

            # If num_inference_steps > 1, perform multi-step sampling (stochastic_iterative_sampler)
-            # Alternate adding noise and evaluating the consistency model
+            # Alternate adding noise and evaluating the consistency model on the noised input
            for i, t in self.progress_bar(enumerate(self.scheduler.timesteps[1:])):
-                sigma = sigmas[t]
-                sigma_prev = sigmas[t - 1]
-                if hasattr(self.scheduler, "add_noise_to_input"):
-                    sample_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)[0]
-                else:
-                    sample_hat = self.add_noise_to_input(sample, sigma, sigma_prev, sigma_min, sigma_max, s_noise=s_noise, generator=generator)
-
-                _, sample = self.denoise(sample_hat, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised)
-        else:
+                step_idx = self.scheduler.index_for_timestep(t)
+                sigma = sigmas[step_idx]
+                sigma_prev = sigmas[step_idx - 1]
+                sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)
+
+                model_output, denoised = self.denoise(
+                    sample_hat, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised
+                )
+
+                sample = self.scheduler.step(denoised, sigma_hat, sigma_prev, sample_hat).prev_sample
+        elif hasattr(self.scheduler, "sigmas"):
            # 4.2 Karras-style scheduler in sigma space (e.g. HeunDiscreteScheduler)
-            sigma_min = self.scheduler.sigmas[-1]
            # TODO: warmup steps logic correct?
            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    step_idx = self.scheduler.index_for_timestep(t)
                    sigma = self.scheduler.sigmas[step_idx]
                    # TODO: handle class labels?
+                    # TODO: check shapes, might need equivalent of s_in in original code
+                    # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L510
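+                    # (In the original code, s_in is just x.new_ones([x.shape[0]]) so a scalar sigma
+                    # broadcasts over the batch dimension; something similar may be needed here.)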
                    model_output, denoised = self.denoise(
                        sample, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised
                    )

                    # Karras-style schedulers already convert to an ODE derivative inside step()
                    sample = self.scheduler.step(denoised, t, sample, **extra_step_kwargs).prev_sample

-                    # TODO: need to handle karras sigma stuff here?
-
-                    # TODO: differs from callback support in original code
+                    # Note: differs from callback support in original code
                    # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L459
                    # call the callback, if provided
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()
                        if callback is not None and i % callback_steps == 0:
                            callback(i, t, sample)
+        else:
+            raise ValueError(
+                f"Scheduler {self.scheduler.__class__} is not compatible with consistency models."
+            )

        # 5. Post-process image sample
        sample = (sample / 2 + 0.5).clamp(0, 1)