Commit 799ab23

initial commit
1 parent 29b1325 commit 799ab23

File tree

4 files changed: +157 −0 lines changed


src/diffusers/pipelines/consistency_models/__init__.py

Whitespace-only changes.
src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py

@@ -0,0 +1,131 @@
import inspect
from typing import List, Optional, Tuple, Union

import torch

from ...models import UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


class ConsistencyModelPipeline(DiffusionPipeline):
    r"""
    TODO
    """

    def __init__(self, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers) -> None:
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
        )

    # Need to handle boundary conditions (e.g. c_skip, c_out, etc.) somewhere.
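    # NOTE (editor's sketch, not part of this commit): the boundary conditions
    # referenced above are the c_skip/c_out scalings from the consistency models
    # paper (Song et al. 2023, https://arxiv.org/abs/2303.01469). Assuming the
    # paper's defaults sigma_data = 0.5 and sigma_min = 0.002, a helper could
    # look roughly like:
    #
    #     def get_scalings_for_boundary_condition(self, sigma, sigma_min=0.002, sigma_data=0.5):
    #         c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
    #         c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    #         return c_skip, c_out
    #
    # with the denoised output then formed as c_skip * sample + c_out * model_output.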
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
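    # NOTE (editor's note, not part of this commit): as an example of the
    # introspection above, DDIMScheduler.step accepts both `eta` and `generator`,
    # so the helper returns {"eta": eta, "generator": generator}; for a scheduler
    # whose `step` accepts neither, it returns an empty dict.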
    def add_noise_to_input(
        self,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        step: int = 0,
    ):
        """
        Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
        higher noise level sigma_hat = sigma_i + gamma_i * sigma_i.

        TODO Args:
        """
        pass
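    # NOTE (editor's sketch, not part of this commit): following the stochastic
    # sampler of Karras et al. 2022 (https://arxiv.org/abs/2206.00364), the body
    # above could look roughly like the following, where `sigma` is the current
    # noise level and `s_churn`/`s_noise` are assumed scheduler attributes;
    # gamma is zero outside the churn range [s_tmin, s_tmax]:
    #
    #     gamma = min(s_churn / num_inference_steps, 2**0.5 - 1)
    #     eps = s_noise * randn_tensor(sample.shape, generator=generator, device=sample.device)
    #     sigma_hat = sigma + gamma * sigma
    #     sample_hat = sample + (sigma_hat**2 - sigma**2) ** 0.5 * eps
    #     return sample_hat, sigma_hat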
    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 2000,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ):
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            num_inference_steps (`int`, *optional*, defaults to 2000):
                The number of denoising steps.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is
            `True`, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        """
        img_size = self.unet.config.sample_size
        shape = (batch_size, 3, img_size, img_size)
        device = self.device

        # 1. Sample image latents x_0 ~ N(0, sigma_0^2 * I)
        sample = randn_tensor(shape, generator=generator, device=device) * self.scheduler.init_noise_sigma

        # 2. Set timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        # TODO: should schedulers always have sigmas? I think the original code always uses sigmas
        # self.scheduler.set_sigmas(num_inference_steps)

        # 3. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 4. Denoising loop
        # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(self.scheduler.timesteps):
                # TODO: handle class labels?
                # The UNet forward pass returns an output dataclass; the scheduler expects the raw tensor.
                model_output = self.unet(sample, t).sample

                sample = self.scheduler.step(model_output, t, sample, **extra_step_kwargs).prev_sample
                progress_bar.update()

                # TODO: need to handle karras sigma stuff here?

        # TODO: need to support callbacks?

        # 5. Post-process image sample
        sample = sample.clamp(0, 1)
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            sample = self.numpy_to_pil(sample)

        if not return_dict:
            return (sample,)

        # TODO: Offload to cpu?

        return ImagePipelineOutput(images=sample)
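For reference, a minimal sketch of how this pipeline might be invoked once a checkpoint exists. The checkpoint id below is hypothetical, and the sketch assumes `ConsistencyModelPipeline` is re-exported from the package root, which this initial commit does not yet set up:

    import torch
    from diffusers import ConsistencyModelPipeline

    # Hypothetical checkpoint id, for illustration only.
    pipe = ConsistencyModelPipeline.from_pretrained("org/consistency-model-checkpoint")
    pipe = pipe.to("cuda")

    # A seeded generator makes generation deterministic.
    generator = torch.Generator(device="cuda").manual_seed(0)
    image = pipe(batch_size=1, num_inference_steps=40, generator=generator).images[0]
    image.save("consistency_sample.png")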

tests/pipelines/consistency_models/__init__.py

Whitespace-only changes.
tests/pipelines/consistency_models/test_consistency_models.py

@@ -0,0 +1,26 @@
import gc
import random
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers.utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu

from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


class ConsistencyModelPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
    pass


@slow
@require_torch_gpu
class ConsistencyModelPipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        # Release GPU memory between tests.
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()
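As written, the fast-test class carries no pipeline wiring yet; diffusers' PipelineTesterMixin expects subclasses to define at minimum a `pipeline_class` attribute, `params`/`batch_params` sets, and `get_dummy_components`/`get_dummy_inputs` methods. A rough sketch of that scaffolding, reusing the imports above; the dummy configs and parameter sets are placeholders, not from this commit, and the `diffusers` imports assume the new pipeline is re-exported:

    from diffusers import ConsistencyModelPipeline, HeunDiscreteScheduler, UNet2DConditionModel

    class ConsistencyModelPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
        pipeline_class = ConsistencyModelPipeline
        params = frozenset(["num_inference_steps", "generator", "output_type", "return_dict"])
        batch_params = frozenset(["batch_size"])

        def get_dummy_components(self):
            # Tiny modules so the fast tests can run on CPU; configs are illustrative.
            unet = UNet2DConditionModel(
                block_out_channels=(32, 64),
                layers_per_block=2,
                sample_size=32,
                in_channels=3,
                out_channels=3,
                down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
                up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
                cross_attention_dim=32,
            )
            scheduler = HeunDiscreteScheduler()
            return {"unet": unet, "scheduler": scheduler}

        def get_dummy_inputs(self, device, seed=0):
            generator = torch.Generator(device=device).manual_seed(seed)
            return {
                "batch_size": 1,
                "num_inference_steps": 2,
                "generator": generator,
                "output_type": "numpy",
            }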
