
Commit 675f0d1

Quick implementation of t2i-adapter
- Load adapter module with from_pretrained
- Prototyping generalized adapter framework
- Writeup doc string for sideload framework (WIP) + some minor update on implementation
- Update adapter models
- Remove old adapter optional args in UNet
- Add StableDiffusionAdapterPipeline unit test
- Handle cpu offload in StableDiffusionAdapterPipeline
- Auto correct coding style
- Update model repo name to "RzZ/sd-v1-4-adapter-pipeline"
- Refactor MultiAdapter to better compatible with config system
- Export MultiAdapter
- Create pipeline document template from controlnet
- Create dummy objects
- Supproting new AdapterLight model
- Fix StableDiffusionAdapterPipeline common pipeline test
- [WIP] Update adapter pipeline document
- Handle num_inference_steps in StableDiffusionAdapterPipeline
- Update definition of Adapter "channels_in"
- Update documents
- Apply code style
- Fix doc typo and merge error
- Update doc string and example
- Quality of life improvement
- Remove redundant code and file from prototyping
- Remove unused pageage
- Remove comments
- Fix title
- Fix typo
- Add conditioning scale arg
- Bring back old implmentation
- Offload sideload
- Add supply info on document
- Update src/diffusers/models/adapter.py (Co-authored-by: Will Berman <[email protected]>)
- Update MultiAdapter constructor
- Swap out custom checkpoint and update pipeline constructor
- Update docment
- Apply suggestions from code review (Co-authored-by: Will Berman <[email protected]>)
- Correcting style
- Following single-file policy
- Update auto size in image preprocess func
- Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py (Co-authored-by: Will Berman <[email protected]>)
- fix copies
- Update adapter pipeline behavior
- Add adapter_conditioning_scale doc string
- Add the missing doc string
- Apply suggestions from code review (Co-authored-by: Patrick von Platen <[email protected]>)
- Fix few bugs from suggestion
- Handle L-mode PIL image as control image
- Rename to differentiate adapter resblock
- Update src/diffusers/models/adapter.py (Co-authored-by: Sayak Paul <[email protected]>)
- Fix typo
- Update adapter parameter name
- Update test case and code style
- Fix copies
- Fix typo
- Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py (Co-authored-by: Will Berman <[email protected]>)
- Update Adapter class name
- Add checkpoint converting script
- Fix style
- Fix-copies
1 parent bdeff4d commit 675f0d1

21 files changed (+2071 −7 lines)

docs/source/en/_toctree.yml

+2

```diff
@@ -197,6 +197,8 @@
        title: Text-to-Image Generation with ControlNet Conditioning
      - local: api/pipelines/stable_diffusion/model_editing
        title: Text-to-Image Model Editing
+     - local: api/pipelines/stable_diffusion/adapter
+       title: Text-to-Image Generation with T2I Adapter Conditioning
      title: Stable Diffusion
    - local: api/pipelines/stable_diffusion_2
      title: Stable Diffusion 2
```
docs/source/en/api/models.mdx

+6

```diff
@@ -82,6 +82,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
 ## ControlNetModel
 [[autodoc]] ControlNetModel
 
+## T2IAdapter
+[[autodoc]] T2IAdapter
+
+## MultiAdapter
+[[autodoc]] MultiAdapter
+
 ## FlaxModelMixin
 [[autodoc]] FlaxModelMixin
```

docs/source/en/api/pipelines/overview.mdx

+1

```diff
@@ -59,6 +59,7 @@ available a colab notebook to directly try them out.
 | [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
 | [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
 | [semantic_stable_diffusion](./semantic_stable_diffusion) | [**SEGA: Instructing Diffusion using Semantic Dimensions**](https://arxiv.org/abs/2301.12247) | Text-to-Image Generation |
+| [stable_diffusion_adapter](./stable_diffusion/adapter) | [**T2I-Adapter**](https://arxiv.org/abs/2302.08453) | Image-to-Image Text-Guided Generation | -
 | [stable_diffusion_text2img](./stable_diffusion/text2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
 | [stable_diffusion_img2img](./stable_diffusion/img2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
 | [stable_diffusion_inpaint](./stable_diffusion/inpaint) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
```
docs/source/en/api/pipelines/stable_diffusion/adapter.mdx (new file)

+189
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Text-to-Image Generation with Adapter Conditioning

## Overview

[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.08453) by Chong Mou, Xintao Wang, Liangbin Xie, Jian Zhang, Zhongang Qi, Ying Shan, Xiaohu Qie.

Using the pretrained models, we can provide control images (for example, a depth map) to condition Stable Diffusion text-to-image generation so that it follows the structure of the control image and fills in the details.

The abstract of the paper is the following:

*The incredible generative ability of large-scale text-to-image (T2I) models has demonstrated strong power of learning complex structures and meaningful semantics. However, relying solely on text prompts cannot fully take advantage of the knowledge learned by the model, especially when flexible and accurate structure control is needed. In this paper, we aim to "dig out" the capabilities that T2I models have implicitly learned, and then explicitly use them to control the generation more granularly. Specifically, we propose to learn simple and small T2I-Adapters to align internal knowledge in T2I models with external control signals, while freezing the original large T2I models. In this way, we can train various adapters according to different conditions, and achieve rich control and editing effects. Further, the proposed T2I-Adapters have attractive properties of practical value, such as composability and generalization ability. Extensive experiments demonstrate that our T2I-Adapter has promising generation quality and a wide range of applications.*
24+
25+
## Available Pipelines:
26+
27+
| Pipeline | Tasks | Demo
28+
|---|---|:---:|
29+
| [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | -
30+
31+
## Usage example
32+
33+
In the following we give a simple example of how to use a *T2IAdapter* checkpoint with Diffusers for inference.
34+
The inference pipeline is the same for all pipelines:
35+
36+
1. Take an image and run it through a pre-conditioning processor to obtain *control image*.
37+
2. Run the pre-processed *control image* and *prompt* through the [`StableDiffusionAdapterPipeline`].
38+
39+
Let's have a look at a simple example using the [Color Adapter](https://huggingface.co/RzZ/sd-v1-4-adapter-color).
40+
41+
```python
42+
from diffusers.utils import load_image
43+
44+
image = load_image("https://huggingface.co/RzZ/sd-v1-4-adapter-color/resolve/main/color_ref.png")
45+
```
![img](https://huggingface.co/RzZ/sd-v1-4-adapter-color/resolve/main/color_ref.png)


Then we can create our color palette by simply resizing the image to 8 by 8 pixels and then scaling it back to the original size.

```python
from PIL import Image

color_palette = image.resize((8, 8))
color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST)
```

Let's take a look at the processed image.

![img](https://huggingface.co/RzZ/sd-v1-4-adapter-color/resolve/main/color_palette.png)


With `color_palette` in hand, we can create the [`StableDiffusionAdapterPipeline`] from a pretrained checkpoint.

```py
import torch
from diffusers import StableDiffusionAdapterPipeline, T2IAdapter

adapter = T2IAdapter.from_pretrained("RzZ/sd-v1-4-adapter-color", torch_dtype=torch.float16)
pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    adapter=adapter,
    torch_dtype=torch.float16,
)
pipe.to("cuda")
```
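
If you are tight on GPU memory, you can optionally trade a bit of speed for memory by enabling attention slicing, one of the memory helpers listed in the [`StableDiffusionAdapterPipeline`] API section at the bottom of this page.

```py
# Optional: reduce peak memory usage at a small cost in speed.
pipe.enable_attention_slicing()
```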
And finally, we feed the data to the pipeline and wait for the result!

```py
# fix the random seed, so you will get the same result as the example
generator = torch.manual_seed(7)

out_image = pipe(
    ["At night, glowing cubes in front of the beach"],
    image=[color_palette],
    generator=generator,
).images[0]
```
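
If you want to keep the result, you can save it like any other PIL image; the filename below is just an example.

```py
# Persist the generated image to disk.
out_image.save("color_adapter_output.png")
```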
This should take only a few seconds on a GPU (depending on the hardware). The output image then looks as follows:

![img](https://huggingface.co/RzZ/sd-v1-4-adapter-color/resolve/main/color_output.png)


**Note**: To see how to run all other Adapter checkpoints, please have a look at [T2I-Adapter with Stable Diffusion 1.4](#t2i-adapter-with-stable-diffusion-1.4).

<!-- TODO: add space -->

## Available checkpoints

The adapter requires a *control image* in addition to the text-to-image *prompt*.
Each pretrained model is trained with a different conditioning method and therefore requires a different kind of control image. For example, Canny edge conditioning requires the control image to be the output of a Canny filter, while depth conditioning requires the control image to be a depth map. See the overview and image examples below to learn more.
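
For illustration, here is how a Canny *control image* could be prepared with OpenCV. This is only a minimal sketch: the `opencv-python` dependency, the input image, and the threshold values are assumptions for the example, not part of an official preprocessing recipe.

```py
import cv2
import numpy as np
from PIL import Image
from diffusers.utils import load_image

# Any RGB photo works as input; here we simply reuse the reference image from the color example.
input_image = load_image("https://huggingface.co/RzZ/sd-v1-4-adapter-color/resolve/main/color_ref.png")

# Convert to a grayscale array and run the Canny edge detector (thresholds are illustrative).
gray = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2GRAY)
edges = cv2.Canny(gray, 100, 200)

# White edges on a black background, which is what the canny adapter expects.
canny_control_image = Image.fromarray(edges)
```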
All official checkpoints can be found under the authors' namespace [TencentARC/T2I-Adapter](https://huggingface.co/TencentARC/T2I-Adapter/tree/main/models).

### T2I-Adapter with Stable Diffusion 1.4

| Model Name | Control Image Overview | Control Image Example | Generated Image Example |
|---|---|---|---|
|[RzZ/sd-v1.4-adapter-color](https://huggingface.co/RzZ/sd-v1-4-adapter-color/)<br/> *Trained with spatial color palette* | An image with an 8x8 color palette.|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-color/resolve/main/sample_input.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/RzZ/sd-v1-4-adapter-color/resolve/main/sample_input.png"/></a>|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-color/resolve/main/sample_output.png"><img width="64" src="https://huggingface.co/RzZ/sd-v1-4-adapter-color/resolve/main/sample_output.png"/></a>|
|[RzZ/sd-v1.4-adapter-canny](https://huggingface.co/RzZ/sd-v1-4-adapter-canny)<br/> *Trained with canny edge detection* | A monochrome image with white edges on a black background.|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-canny/resolve/main/sample_input.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/RzZ/sd-v1-4-adapter-canny/resolve/main/sample_input.png"/></a>|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-canny/resolve/main/sample_output.png"><img width="64" src="https://huggingface.co/RzZ/sd-v1-4-adapter-canny/resolve/main/sample_output.png"/></a>|
|[RzZ/sd-v1.4-adapter-sketch](https://huggingface.co/RzZ/sd-v1-4-adapter-sketch)<br/> *Trained with [PidiNet](https://github.com/zhuoinoulu/pidinet) edge detection* | A hand-drawn monochrome image with white outlines on a black background.|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-sketch/resolve/main/sample_input.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/RzZ/sd-v1-4-adapter-sketch/resolve/main/sample_input.png"/></a>|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-sketch/resolve/main/sample_output.png"><img width="64" src="https://huggingface.co/RzZ/sd-v1-4-adapter-sketch/resolve/main/sample_output.png"/></a>|
|[RzZ/sd-v1.4-adapter-depth](https://huggingface.co/RzZ/sd-v1-4-adapter-depth)<br/> *Trained with Midas depth estimation* | A grayscale image with black representing deep areas and white representing shallow areas.|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-depth/resolve/main/sample_input.png"><img width="64" src="https://huggingface.co/RzZ/sd-v1-4-adapter-depth/resolve/main/sample_input.png"/></a>|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-depth/resolve/main/sample_output.png"><img width="64" src="https://huggingface.co/RzZ/sd-v1-4-adapter-depth/resolve/main/sample_output.png"/></a>|
|[RzZ/sd-v1.4-adapter-openpose](https://huggingface.co/RzZ/sd-v1-4-adapter-openpose)<br/> *Trained with OpenPose bone image* | An [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image.|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-openpose/resolve/main/sample_input.png"><img width="64" src="https://huggingface.co/RzZ/sd-v1-4-adapter-openpose/resolve/main/sample_input.png"/></a>|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-openpose/resolve/main/sample_output.png"><img width="64" src="https://huggingface.co/RzZ/sd-v1-4-adapter-openpose/resolve/main/sample_output.png"/></a>|
|[RzZ/sd-v1.4-adapter-keypose](https://huggingface.co/RzZ/sd-v1-4-adapter-keypose)<br/> *Trained with mmpose skeleton image* | An [mmpose skeleton](https://github.com/open-mmlab/mmpose) image.|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-keypose/resolve/main/sample_input.png"><img width="64" src="https://huggingface.co/RzZ/sd-v1-4-adapter-keypose/resolve/main/sample_input.png"/></a>|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-keypose/resolve/main/sample_output.png"><img width="64" src="https://huggingface.co/RzZ/sd-v1-4-adapter-keypose/resolve/main/sample_output.png"/></a>|
|[RzZ/sd-v1.4-adapter-seg](https://huggingface.co/RzZ/sd-v1-4-adapter-seg)<br/>*Trained with semantic segmentation* | A [custom](https://github.com/TencentARC/T2I-Adapter/discussions/25) segmentation protocol image.|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-seg/resolve/main/sample_input.png"><img width="64" src="https://huggingface.co/RzZ/sd-v1-4-adapter-seg/resolve/main/sample_input.png"/></a>|<a href="https://huggingface.co/RzZ/sd-v1-4-adapter-seg/resolve/main/sample_output.png"><img width="64" src="https://huggingface.co/RzZ/sd-v1-4-adapter-seg/resolve/main/sample_output.png"/></a> |
## Mix and match multiple adapters

[`StableDiffusionAdapterPipeline`] also supports using multiple types of *control image* at once in combination with [`MultiAdapter`].
Here is an example that uses the keypose adapter to control the character's posture and the depth adapter to outline the background.

Just like in the previous example, we first prepare the *control images* for inference. One big difference when using [`MultiAdapter`] is that the *control image* we send to the pipeline is
combined from multiple images. In this example we stack two 3-channel RGB images (`cond_keypose`, `cond_depth`) together to create a 6-channel image tensor (`cond`).
```py
import torch
from PIL import Image
from diffusers.utils import load_image

cond_keypose = load_image(
    "https://huggingface.co/RzZ/sd-v1-4-adapter-keypose-depth/resolve/main/sample_input_keypose.png"
)
cond_depth = load_image("https://huggingface.co/RzZ/sd-v1-4-adapter-keypose-depth/resolve/main/sample_input_depth.png")
cond = [[cond_keypose, cond_depth]]

prompt = ["A man walking in an office room with nice view"]
```
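
For intuition, the channel stacking described above amounts to something like the following. This is only an illustrative sketch of the tensor shapes; the pipeline performs the combination internally, so you do not need to run it yourself.

```py
import numpy as np
import torch

# Each control image is a 3-channel RGB image of the same spatial size.
keypose_t = torch.from_numpy(np.array(cond_keypose)).permute(2, 0, 1) / 255.0  # (3, H, W)
depth_t = torch.from_numpy(np.array(cond_depth)).permute(2, 0, 1) / 255.0      # (3, H, W)

# Concatenating along the channel dimension yields the 6-channel conditioning tensor.
stacked = torch.cat([keypose_t, depth_t], dim=0)
print(stacked.shape)  # torch.Size([6, H, W])
```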
The two *control images* should look as follows:

![img](https://huggingface.co/RzZ/sd-v1-4-adapter-keypose-depth/resolve/main/sample_input_keypose.png)
![img](https://huggingface.co/RzZ/sd-v1-4-adapter-keypose-depth/resolve/main/sample_input_depth.png)


Now we can combine the keypose and depth adapters into a single [`MultiAdapter`] and pass it to the
[`StableDiffusionAdapterPipeline`]. You can also play around with the value of `adapter_conditioning_scale` to balance the control between the adapters.
```py
from diffusers import MultiAdapter, StableDiffusionAdapterPipeline, T2IAdapter

adapters = MultiAdapter(
    [
        T2IAdapter.from_pretrained("RzZ/sd-v1-4-adapter-keypose"),
        T2IAdapter.from_pretrained("RzZ/sd-v1-4-adapter-depth"),
    ]
)
adapters = adapters.to(torch.float16)

pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
    adapter=adapters,
)
pipe.to("cuda")

images = pipe(prompt, cond, adapter_conditioning_scale=[0.8, 0.8]).images
```
After the prompt and the control images are processed by the pipeline, we should get a result that looks like this:

![img](https://huggingface.co/RzZ/sd-v1-4-adapter-keypose-depth/resolve/main/sample_output.png)


## T2I Adapter vs ControlNet

T2I-Adapter is similar to ControlNet. However, T2I-Adapter uses a smaller auxiliary network which is only run once for the entire diffusion process, whereas a ControlNet must be evaluated at every denoising step. T2I-Adapter tends to perform slightly worse than ControlNet, but it is cheaper to run, and it is also cheaper to run multiple auxiliary networks at once.
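
To make the cost difference concrete, here is a small, self-contained toy sketch that only counts how often each auxiliary network has to be evaluated. `TinyAux`, the shapes, and the loops are made up for illustration and have nothing to do with the real model architectures or the diffusers API.

```py
import torch
import torch.nn as nn


# Toy stand-in for an auxiliary network; it merely counts its own forward passes.
class TinyAux(nn.Module):
    def __init__(self):
        super().__init__()
        self.calls = 0
        self.conv = nn.Conv2d(3, 4, 3, padding=1)

    def forward(self, x):
        self.calls += 1
        return self.conv(x)


control_image = torch.randn(1, 3, 64, 64)
latents = torch.randn(1, 4, 64, 64)
num_steps = 50

# T2I-Adapter style: the auxiliary network only sees the control image,
# so it runs once and its features are reused at every denoising step.
adapter = TinyAux()
adapter_features = adapter(control_image)
for _ in range(num_steps):
    latents = latents + 0.0 * adapter_features  # stand-in for the UNet consuming the features

# ControlNet style: the auxiliary network also depends on the current latents,
# so it has to be re-evaluated at every denoising step.
controlnet = TinyAux()
for _ in range(num_steps):
    control_residuals = controlnet(latents[:, :3])  # toy latent-dependent input
    latents = latents + 0.0 * control_residuals

print(adapter.calls, controlnet.calls)  # 1 vs. 50
```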
## StableDiffusionAdapterPipeline
[[autodoc]] StableDiffusionAdapterPipeline
    - all
    - __call__
    - enable_attention_slicing
    - disable_attention_slicing
    - enable_vae_slicing
    - disable_vae_slicing
    - enable_xformers_memory_efficient_attention
    - disable_xformers_memory_efficient_attention

docs/source/en/index.mdx

+1

```diff
@@ -66,6 +66,7 @@ The library has three main components:
 | [score_sde_ve](./api/pipelines/score_sde_ve) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
 | [score_sde_vp](./api/pipelines/score_sde_vp) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
 | [semantic_stable_diffusion](./api/pipelines/semantic_stable_diffusion) | [Semantic Guidance](https://arxiv.org/abs/2301.12247) | Text-Guided Generation |
+| [stable_diffusion_adapter](./api/pipelines/stable_diffusion/adapter) | [**T2I-Adapter**](https://arxiv.org/abs/2302.08453) | Image-to-Image Text-Guided Generation | -
 | [stable_diffusion_text2img](./api/pipelines/stable_diffusion/text2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation |
 | [stable_diffusion_img2img](./api/pipelines/stable_diffusion/img2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation |
 | [stable_diffusion_inpaint](./api/pipelines/stable_diffusion/inpaint) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting |
```

docs/source/en/using-diffusers/controlling_generation.mdx

+11

```diff
@@ -37,6 +37,7 @@ Unless otherwise mentioned, these are techniques that work with existing models
 9. [Textual Inversion](#textual-inversion)
 10. [ControlNet](#controlnet)
 11. [Prompt Weighting](#prompt-weighting)
+12. [T2I-Adapter](#t2i-adapter)
 
 ## Instruct Pix2Pix
 
@@ -165,3 +166,13 @@
 input.
 
 For a more in-detail explanation and examples, see [here](../using-diffusers/weighted_prompts).
+
+## T2I-Adapter
+
+[Paper](https://arxiv.org/abs/2302.08453)
+
+[T2I-Adapter](../api/pipelines/stable_diffusion/adapter) is an auxiliary network which adds an extra condition.
+There are 8 canonical pre-trained adapters trained on different conditionings such as edge detection, sketch,
+depth maps, and semantic segmentations.
+
+See [here](../api/pipelines/stable_diffusion/adapter) for more information on how to use it.
```
