import inspect
- from typing import List, Optional, Tuple, Union
+ from typing import List, Optional, Tuple, Union, Callable

import torch

from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+ def append_dims(x, target_dims):
+     """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+     dims_to_append = target_dims - x.ndim
+     if dims_to_append < 0:
+         raise ValueError(
+             f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
+         )
+     return x[(...,) + (None,) * dims_to_append]
+
+
class ConsistencyModelPipeline(DiffusionPipeline):
    r"""
    TODO
@@ -40,30 +51,76 @@ def prepare_extra_step_kwargs(self, generator, eta):
        extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

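+     # The scalings below implement the consistency model parametrization
+     # f(x, sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, sigma);
+     # the boundary-condition variant additionally enforces f(x, sigma_min) = x.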
+     def get_scalings(self, sigma, sigma_data: float = 0.5):
+         c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
+         c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+         c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5
+         return c_skip, c_out, c_in
+
+     def get_scalings_for_boundary_condition(self, sigma, sigma_min, sigma_data: float = 0.5):
+         c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
+         c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+         c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5
+         return c_skip, c_out, c_in
+
+     def denoise(self, x_t, sigma, sigma_min, sigma_data: float = 0.5, clip_denoised=True):
+         """
+         Run one evaluation of the consistency model and map the raw network output to a denoised sample.
+         """
+         c_skip, c_out, c_in = [
+             append_dims(x, x_t.ndim)
+             for x in self.get_scalings_for_boundary_condition(sigma, sigma_min, sigma_data=sigma_data)
+         ]
+         # Timestep rescaling used by the original consistency models implementation
+         rescaled_t = 1000 * 0.25 * torch.log(sigma + 1e-44)
+         model_output = self.unet(c_in * x_t, rescaled_t).sample
+         denoised = c_out * model_output + c_skip * x_t
+         if clip_denoised:
+             denoised = denoised.clamp(-1, 1)
+         return model_output, denoised
+
+     def to_d(self, x, sigma, denoised):
+         """Converts a denoiser output to a Karras ODE derivative."""
+         return (x - denoised) / append_dims(sigma, x.ndim)
+
    def add_noise_to_input(
-         self,
-         sample: torch.FloatTensor,
-         generator: Optional[torch.Generator] = None,
-         step: int = 0
-     ):
-         """
-         Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
-         higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
-         TODO Args:
-         """
-         pass
+         self,
+         sample: torch.FloatTensor,
+         sigma_hat: float,
+         sigma_min: float,
+         sigma_max: float,
+         s_noise: float = 1.0,
+         generator: Optional[torch.Generator] = None,
+     ):
+         # Clamp sigma_hat to the valid noise range of the schedule
+         sigma_hat = sigma_hat.clamp(min=sigma_min, max=sigma_max)

+         # Sample z ~ N(0, s_noise^2 * I)
+         z = s_noise * randn_tensor(sample.shape, generator=generator, device=sample.device)
+
+         # Bring the sample from noise level sigma_min up to sigma_hat (tau = sigma_hat, eps = sigma_min)
+         sample_hat = sample + ((sigma_hat**2 - sigma_min**2) ** 0.5 * z)
+
+         return sample_hat

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
-         num_inference_steps: int = 2000,
+         num_inference_steps: int = 40,
+         clip_denoised: bool = True,
+         sigma_data: float = 0.5,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
-         **kwargs,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         callback_steps: int = 1,
    ):
        r"""
        Args:
@@ -87,33 +144,72 @@ def __call__(
        img_size = self.unet.config.sample_size
        shape = (batch_size, 3, img_size, img_size)
        device = self.device
+         scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
+         scheduler_has_sigma_min = hasattr(self.scheduler, "sigma_min")
+         assert scheduler_has_sigma_min or scheduler_is_in_sigma_space, "Scheduler needs to have `sigma_min` or `sigmas`"
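+         # Two scheduler families are supported: schedulers that expose their own sigma_min/sigma_max
+         # and noise schedule (e.g. KarrasVeScheduler), handled in branch 4.1 below, and sigma-space
+         # schedulers that expose a `sigmas` tensor (e.g. HeunDiscreteScheduler), handled in branch 4.2.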

        # 1. Sample image latents x_0 ~ N(0, sigma_0^2 * I)
        sample = randn_tensor(shape, generator=generator, device=device) * self.scheduler.init_noise_sigma

-         # 2. Set timesteps
+         # 2. Set timesteps and get sigmas
        self.scheduler.set_timesteps(num_inference_steps)
-         # TODO: should schedulers always have sigmas? I think the original code always uses sigmas
-         # self.scheduler.set_sigmas(num_inference_steps)
+         timesteps = self.scheduler.timesteps

        # 3. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 4. Denoising loop
-         # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-         with self.progress_bar(total=num_inference_steps) as progress_bar:
-             for i, t in enumerate(self.scheduler.timesteps):
-                 # TODO: handle class labels?
-                 model_output = self.unet(sample, t)
-
-                 sample = self.scheduler.step(model_output, t, sample, **extra_step_kwargs).prev_sample
-
-                 # TODO: need to handle karras sigma stuff here?
-
-                 # TODO: need to support callbacks?
+         if scheduler_has_sigma_min:
+             # 4.1 Scheduler which can add noise to input (e.g. KarrasVeScheduler)
+             sigma_min = self.scheduler.sigma_min
+             sigma_max = self.scheduler.sigma_max
+             s_noise = self.scheduler.s_noise
+             sigmas = self.scheduler.schedule
+
+             # First evaluate the consistency model. This will be the output sample if num_inference_steps == 1
+             sigma = sigmas[timesteps[0]]
+             _, sample = self.denoise(sample, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised)
+
+             # If num_inference_steps > 1, perform multi-step sampling (stochastic_iterative_sampler)
+             # Alternate adding noise and evaluating the consistency model
+             for i, t in self.progress_bar(enumerate(self.scheduler.timesteps[1:])):
+                 sigma = sigmas[t]
+                 sigma_prev = sigmas[t - 1]
+                 if hasattr(self.scheduler, "add_noise_to_input"):
+                     sample_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)[0]
+                 else:
+                     sample_hat = self.add_noise_to_input(sample, sigma, sigma_min, sigma_max, s_noise=s_noise, generator=generator)
+
+                 _, sample = self.denoise(sample_hat, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised)
+         else:
+             # 4.2 Karras-style scheduler in sigma space (e.g. HeunDiscreteScheduler)
+             sigma_min = self.scheduler.sigmas[-1]
+             # TODO: warmup steps logic correct?
+             num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+             with self.progress_bar(total=num_inference_steps) as progress_bar:
+                 for i, t in enumerate(timesteps):
+                     step_idx = self.scheduler.index_for_timestep(t)
+                     sigma = self.scheduler.sigmas[step_idx]
+                     # TODO: handle class labels?
+                     model_output, denoised = self.denoise(
+                         sample, sigma, sigma_min, sigma_data=sigma_data, clip_denoised=clip_denoised
+                     )
+
+                     # Karras-style schedulers already convert to an ODE derivative inside step()
+                     sample = self.scheduler.step(denoised, t, sample, **extra_step_kwargs).prev_sample
+
+                     # TODO: need to handle karras sigma stuff here?
+
+                     # TODO: differs from callback support in original code
+                     # See e.g. https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L459
+                     # call the callback, if provided
+                     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                         progress_bar.update()
+                         if callback is not None and i % callback_steps == 0:
+                             callback(i, t, sample)

        # 5. Post-process image sample
-         sample = sample.clamp(0, 1)
+         sample = (sample / 2 + 0.5).clamp(0, 1)
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
@@ -125,7 +221,3 @@ def __call__(
        # TODO: Offload to cpu?

        return ImagePipelineOutput(images=sample)
-
-
-
-
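A minimal usage sketch for the pipeline this diff adds, assuming the class ends up exported from `diffusers`. The checkpoint id is a placeholder (any repo bundling a compatible UNet2DModel and a Karras-style scheduler), and class-conditional generation is still a TODO in the code above:

    from diffusers import ConsistencyModelPipeline

    # Placeholder checkpoint id; substitute a real consistency model repo.
    pipe = ConsistencyModelPipeline.from_pretrained("org/consistency-model-checkpoint")
    pipe = pipe.to("cuda")

    # num_inference_steps=1 returns the first consistency model evaluation directly;
    # larger values run the multistep noise-then-denoise loop implemented above.
    image = pipe(batch_size=1, num_inference_steps=1).images[0]
    image.save("consistency_sample.png")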