Skip to content

Commit 03b7a84

Browse files
yiyixuxuyiyixuxuayushtuesayushmangalpatrickvonplaten
authored
Add Kandinsky 2.1 (#3308)
add kandinsky2.1 --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Ayush Mangal <[email protected]> Co-authored-by: ayushmangal <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent f19f128 commit 03b7a84

26 files changed

+5497
-42
lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@
166166
title: DiT
167167
- local: api/pipelines/if
168168
title: IF
169+
- local: api/pipelines/kandinsky
170+
title: Kandinsky
169171
- local: api/pipelines/latent_diffusion
170172
title: Latent Diffusion
171173
- local: api/pipelines/paint_by_example
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
3+
the License. You may obtain a copy of the License at
4+
http://www.apache.org/licenses/LICENSE-2.0
5+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
6+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
7+
specific language governing permissions and limitations under the License.
8+
-->
9+
10+
# Kandinsky
11+
12+
## Overview
13+
14+
Kandinsky 2.1 inherits best practices from [DALL-E 2](https://arxiv.org/abs/2204.06125) and [Latent Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/latent_diffusion), while introducing some new ideas.
15+
16+
It uses [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for encoding images and text, and a diffusion image prior (mapping) between latent spaces of CLIP modalities. This approach enhances the visual performance of the model and unveils new horizons in blending images and text-guided image manipulation.
17+
18+
The Kandinsky model is created by [Arseniy Shakhmatov](https://github.com/cene555), [Anton Razzhigaev](https://github.com/razzant), [Aleksandr Nikolich](https://github.com/AlexWortega), [Igor Pavlov](https://github.com/boomb0om), [Andrey Kuznetsov](https://github.com/kuznetsoffandrey) and [Denis Dimitrov](https://github.com/denndimitrov) and the original codebase can be found [here](https://github.com/ai-forever/Kandinsky-2)
19+
20+
## Available Pipelines:
21+
22+
| Pipeline | Tasks | Colab
23+
|---|---|:---:|
24+
| [pipeline_kandinsky.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py) | *Text-to-Image Generation* | - |
25+
| [pipeline_kandinsky_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py) | *Image-Guided Image Generation* | - |
26+
| [pipeline_kandinsky_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py) | *Image-Guided Image Generation* | - |
27+
28+
## Usage example
29+
30+
In the following, we will walk you through some cool examples of using the Kandinsky pipelines to create some visually aesthetic artwork.
31+
32+
### Text-to-Image Generation
33+
34+
For text-to-image generation, we need to use both [`KandinskyPriorPipeline`] and [`KandinskyPipeline`]. The first step is to encode text prompts with CLIP and then diffuse the CLIP text embeddings to CLIP image embeddings, as first proposed in [DALL-E 2](https://cdn.openai.com/papers/dall-e-2.pdf). Let's throw a fun prompt at Kandinsky to see what it comes up with :)
35+
36+
```python
37+
prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
38+
negative_prompt = "low quality, bad quality"
39+
```
40+
41+
We will pass both the `prompt` and `negative_prompt` to our prior diffusion pipeline. In contrast to other diffusion pipelines, such as Stable Diffusion, the `prompt` and `negative_prompt` shall be passed separately so that we can retrieve a CLIP image embedding for each prompt input. You can use `guidance_scale`, and `num_inference_steps` arguments to guide this process, just like how you would normally do with all other pipelines in diffusers.
42+
43+
```python
44+
from diffusers import KandinskyPriorPipeline
45+
import torch
46+
47+
# create prior
48+
pipe_prior = KandinskyPriorPipeline.from_pretrained(
49+
"kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
50+
)
51+
pipe_prior.to("cuda")
52+
53+
generator = torch.Generator(device="cuda").manual_seed(12)
54+
image_emb = pipe_prior(
55+
prompt, guidance_scale=1.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
56+
).images
57+
58+
zero_image_emb = pipe_prior(
59+
negative_prompt, guidance_scale=1.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
60+
).images
61+
```
62+
63+
Once we create the image embedding, we can use [`KandinskyPipeline`] to generate images.
64+
65+
```python
66+
from PIL import Image
67+
from diffusers import KandinskyPipeline
68+
69+
70+
def image_grid(imgs, rows, cols):
71+
assert len(imgs) == rows * cols
72+
73+
w, h = imgs[0].size
74+
grid = Image.new("RGB", size=(cols * w, rows * h))
75+
grid_w, grid_h = grid.size
76+
77+
for i, img in enumerate(imgs):
78+
grid.paste(img, box=(i % cols * w, i // cols * h))
79+
return grid
80+
81+
82+
# create diffuser pipeline
83+
pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
84+
pipe.to("cuda")
85+
86+
images = pipe(
87+
prompt,
88+
image_embeds=image_emb,
89+
negative_image_embeds=zero_image_emb,
90+
num_images_per_prompt=2,
91+
height=768,
92+
width=768,
93+
num_inference_steps=100,
94+
guidance_scale=4.0,
95+
generator=generator,
96+
).images
97+
```
98+
99+
One cheeseburger monster coming up! Enjoy!
100+
101+
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/cheeseburger.png)
102+
103+
The Kandinsky model works extremely well with creative prompts. Here is some of the amazing art that can be created using the exact same process but with different prompts.
104+
105+
```python
106+
prompt = "bird eye view shot of a full body woman with cyan light orange magenta makeup, digital art, long braided hair her face separated by makeup in the style of yin Yang surrealism, symmetrical face, real image, contrasting tone, pastel gradient background"
107+
```
108+
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/hair.png)
109+
110+
```python
111+
prompt = "A car exploding into colorful dust"
112+
```
113+
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/dusts.png)
114+
115+
```python
116+
prompt = "editorial photography of an organic, almost liquid smoke style armchair"
117+
```
118+
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/smokechair.png)
119+
120+
```python
121+
prompt = "birds eye view of a quilted paper style alien planet landscape, vibrant colours, Cinematic lighting"
122+
```
123+
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/alienplanet.png)
124+
125+
126+
### Text Guided Image-to-Image Generation
127+
128+
The same Kandinsky model weights can be used for text-guided image-to-image translation. In this case, just make sure to load the weights using the [`KandinskyImg2ImgPipeline`] pipeline.
129+
130+
**Note**: You can also directly move the weights of the text-to-image pipelines to the image-to-image pipelines
131+
without loading them twice by making use of the [`~DiffusionPipeline.components`] function as explained [here](#converting-between-different-pipelines).
132+
133+
Let's download an image.
134+
135+
```python
136+
from PIL import Image
137+
import requests
138+
from io import BytesIO
139+
140+
# download image
141+
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
142+
response = requests.get(url)
143+
original_image = Image.open(BytesIO(response.content)).convert("RGB")
144+
original_image = original_image.resize((768, 512))
145+
```
146+
147+
![img](https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg)
148+
149+
```python
150+
import torch
151+
from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline
152+
153+
# create prior
154+
pipe_prior = KandinskyPriorPipeline.from_pretrained(
155+
"kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
156+
)
157+
pipe_prior.to("cuda")
158+
159+
# create img2img pipeline
160+
pipe = KandinskyImg2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
161+
pipe.to("cuda")
162+
163+
prompt = "A fantasy landscape, Cinematic lighting"
164+
negative_prompt = "low quality, bad quality"
165+
166+
generator = torch.Generator(device="cuda").manual_seed(30)
167+
image_emb = pipe_prior(
168+
prompt, guidance_scale=4.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
169+
).images
170+
171+
zero_image_emb = pipe_prior(
172+
negative_prompt, guidance_scale=4.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
173+
).images
174+
175+
out = pipe(
176+
prompt,
177+
image=original_image,
178+
image_embeds=image_emb,
179+
negative_image_embeds=zero_image_emb,
180+
height=768,
181+
width=768,
182+
num_inference_steps=500,
183+
strength=0.3,
184+
)
185+
186+
out.images[0].save("fantasy_land.png")
187+
```
188+
189+
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/img2img_fantasyland.png)
190+
191+
192+
### Text Guided Inpainting Generation
193+
194+
You can use [`KandinskyInpaintPipeline`] to edit images. In this example, we will add a hat to the portrait of a cat.
195+
196+
```python
197+
from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline
198+
from diffusers.utils import load_image
199+
import torch
200+
import numpy as np
201+
202+
pipe_prior = KandinskyPriorPipeline.from_pretrained(
203+
"kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
204+
)
205+
pipe_prior.to("cuda")
206+
207+
prompt = "a hat"
208+
image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)
209+
210+
pipe = KandinskyInpaintPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16)
211+
pipe.to("cuda")
212+
213+
init_image = load_image(
214+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
215+
)
216+
217+
mask = np.ones((768, 768), dtype=np.float32)
218+
# Let's mask out an area above the cat's head
219+
mask[:250, 250:-250] = 0
220+
221+
out = pipe(
222+
prompt,
223+
image=init_image,
224+
mask_image=mask,
225+
image_embeds=image_emb,
226+
negative_image_embeds=zero_image_emb,
227+
height=768,
228+
width=768,
229+
num_inference_steps=150,
230+
)
231+
232+
image = out.images[0]
233+
image.save("cat_with_hat.png")
234+
```
235+
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/inpaint_cat_hat.png)
236+
237+
### Interpolate
238+
239+
The [`KandinskyPriorPipeline`] also comes with a cool utility function that will allow you to interpolate the latent space of different images and texts super easily. Here is an example of how you can create an Impressionist-style portrait for your pet based on "The Starry Night".
240+
241+
Note that you can interpolate between texts and images - in the below example, we passed a text prompt "a cat" and two images to the `interplate` function, along with a `weights` variable containing the corresponding weights for each condition we interplate.
242+
243+
```python
244+
from diffusers import KandinskyPriorPipeline, KandinskyPipeline
245+
from diffusers.utils import load_image
246+
import PIL
247+
248+
import torch
249+
from torchvision import transforms
250+
251+
pipe_prior = KandinskyPriorPipeline.from_pretrained(
252+
"kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
253+
)
254+
pipe_prior.to("cuda")
255+
256+
img1 = load_image(
257+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
258+
)
259+
260+
img2 = load_image(
261+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/starry_night.jpeg"
262+
)
263+
264+
# add all the conditions we want to interpolate, can be either text or image
265+
images_texts = ["a cat", img1, img2]
266+
# specify the weights for each condition in images_texts
267+
weights = [0.3, 0.3, 0.4]
268+
image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)
269+
270+
pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
271+
pipe.to("cuda")
272+
273+
image = pipe(
274+
"", image_embeds=image_emb, negative_image_embeds=zero_image_emb, height=768, width=768, num_inference_steps=150
275+
).images[0]
276+
277+
image.save("starry_cat.png")
278+
```
279+
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/starry_cat.png)
280+
281+
282+
## KandinskyPriorPipeline
283+
284+
[[autodoc]] KandinskyPriorPipeline
285+
- all
286+
- __call__
287+
- interpolate
288+
289+
## KandinskyPipeline
290+
291+
[[autodoc]] KandinskyPipeline
292+
- all
293+
- __call__
294+
295+
## KandinskyInpaintPipeline
296+
297+
[[autodoc]] KandinskyInpaintPipeline
298+
- all
299+
- __call__
300+
301+
## KandinskyImg2ImgPipeline
302+
303+
[[autodoc]] KandinskyImg2ImgPipeline
304+
- all
305+
- __call__
306+

0 commit comments

Comments
 (0)