
Commit 7c02ee1

williamberman and HimariO authored and committed
t2i pipeline (huggingface#3932)
* Quick implementation of t2i-adapter
  Load adapter module with from_pretrained
  Prototyping generalized adapter framework
  Writeup doc string for sideload framework(WIP) + some minor update on implementation
  Update adapter models
  Remove old adapter optional args in UNet
  Add StableDiffusionAdapterPipeline unit test
  Handle cpu offload in StableDiffusionAdapterPipeline
  Auto correct coding style
  Update model repo name to "RzZ/sd-v1-4-adapter-pipeline"
  Refactor MultiAdapter to better compatible with config system
  Export MultiAdapter
  Create pipeline document template from controlnet
  Create dummy objects
  Supproting new AdapterLight model
  Fix StableDiffusionAdapterPipeline common pipeline test
  [WIP] Update adapter pipeline document
  Handle num_inference_steps in StableDiffusionAdapterPipeline
  Update definition of Adapter "channels_in"
  Update documents
  Apply code style
  Fix doc typo and merge error
  Update doc string and example
  Quality of life improvement
  Remove redundant code and file from prototyping
  Remove unused pageage
  Remove comments
  Fix title
  Fix typo
  Add conditioning scale arg
  Bring back old implmentation
  Offload sideload
  Add supply info on document
  Update src/diffusers/models/adapter.py
  Co-authored-by: Will Berman <[email protected]>
  Update MultiAdapter constructor
  Swap out custom checkpoint and update pipeline constructor
  Update docment
  Apply suggestions from code review
  Co-authored-by: Will Berman <[email protected]>
  Correcting style
  Following single-file policy
  Update auto size in image preprocess func
  Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py
  Co-authored-by: Will Berman <[email protected]>
  fix copies
  Update adapter pipeline behavior
  Add adapter_conditioning_scale doc string
  Add the missing doc string
  Apply suggestions from code review
  Co-authored-by: Patrick von Platen <[email protected]>
  Fix few bugs from suggestion
  Handle L-mode PIL image as control image
  Rename to differentiate adapter resblock
  Update src/diffusers/models/adapter.py
  Co-authored-by: Sayak Paul <[email protected]>
  Fix typo
  Update adapter parameter name
  Update test case and code style
  Fix copies
  Fix typo
  Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py
  Co-authored-by: Will Berman <[email protected]>
  Update Adapter class name
  Add checkpoint converting script
  Fix style
  Fix-copies
  Remove dev script
  Apply suggestions from code review
  Co-authored-by: Patrick von Platen <[email protected]>
  Updates for parameter rename
  Fix convert_adapter
  remove main
  fix diff
  more refactoring
  more
  more small fixes
  refactor tests
  more slow tests
  more tests
  Update docs/source/en/api/pipelines/overview.mdx
  Co-authored-by: Sayak Paul <[email protected]>
  add community contributor to docs
  Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx
  Co-authored-by: Sayak Paul <[email protected]>
  Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx
  Co-authored-by: Sayak Paul <[email protected]>
  Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx
  Co-authored-by: Sayak Paul <[email protected]>
  Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx
  Co-authored-by: Sayak Paul <[email protected]>
  Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx
  Co-authored-by: Sayak Paul <[email protected]>
  fix
  remove from_adapters
  license
  paper link
  docs
  more url fixes
  more docs fix
  fixes
  fix
  fix
* fix sample inplace add
* additional_kwargs -> additional_residuals
* move t2i adapter pipeline to own module
* preprocess -> _preprocess_adapter_image
* add TencentArc to license
* fix example code links
* add image converter and fix example doc string
* fix links
* clearer additional residual application

---------

Co-authored-by: HimariO <[email protected]>

1 parent c1c9b1c commit 7c02ee1

18 files changed (+1989 −9 lines)

docs/source/en/_toctree.yml

+4 −2

```diff
@@ -196,12 +196,12 @@
   title: DDIM
 - local: api/pipelines/ddpm
   title: DDPM
+- local: api/pipelines/deepfloyd_if
+  title: DeepFloyd IF
 - local: api/pipelines/diffedit
   title: DiffEdit
 - local: api/pipelines/dit
   title: DiT
-- local: api/pipelines/deepfloyd_if
-  title: DeepFloyd IF
 - local: api/pipelines/pix2pix
   title: InstructPix2Pix
 - local: api/pipelines/kandinsky
@@ -257,6 +257,8 @@
   title: Super-Resolution
 - local: api/pipelines/stable_diffusion/ldm3d_diffusion
   title: LDM3D Text-to-(RGB, Depth)
+- local: api/pipelines/stable_diffusion/adapter
+  title: Stable Diffusion T2I-adapter
 title: Stable Diffusion
 - local: api/pipelines/stable_unclip
 title: Stable unCLIP
```

docs/source/en/api/pipelines/overview.mdx

+1

```diff
@@ -66,6 +66,7 @@ available a colab notebook to directly try them out.
 | [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
 | [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
 | [semantic_stable_diffusion](./semantic_stable_diffusion) | [**SEGA: Instructing Diffusion using Semantic Dimensions**](https://arxiv.org/abs/2301.12247) | Text-to-Image Generation |
+| [stable_diffusion_adapter](./stable_diffusion/adapter) | [**T2I-Adapter**](https://arxiv.org/abs/2302.08453) | Image-to-Image Text-Guided Generation with Adapters | -
 | [stable_diffusion_text2img](./stable_diffusion/text2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
 | [stable_diffusion_img2img](./stable_diffusion/img2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
 | [stable_diffusion_inpaint](./stable_diffusion/inpaint) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
```

docs/source/en/api/pipelines/stable_diffusion/adapter.mdx

+187

<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Text-to-Image Generation with Adapter Conditioning

## Overview

[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.08453) by Chong Mou, Xintao Wang, Liangbin Xie, Jian Zhang, Zhongang Qi, Ying Shan, Xiaohu Qie.

Using the pretrained models, we can provide control images (for example, a depth map) to condition Stable Diffusion text-to-image generation so that it follows the structure of the depth image and fills in the details.

The abstract of the paper is the following:

*The incredible generative ability of large-scale text-to-image (T2I) models has demonstrated strong power of learning complex structures and meaningful semantics. However, relying solely on text prompts cannot fully take advantage of the knowledge learned by the model, especially when flexible and accurate structure control is needed. In this paper, we aim to "dig out" the capabilities that T2I models have implicitly learned, and then explicitly use them to control the generation more granularly. Specifically, we propose to learn simple and small T2I-Adapters to align internal knowledge in T2I models with external control signals, while freezing the original large T2I models. In this way, we can train various adapters according to different conditions, and achieve rich control and editing effects. Further, the proposed T2I-Adapters have attractive properties of practical value, such as composability and generalization ability. Extensive experiments demonstrate that our T2I-Adapter has promising generation quality and a wide range of applications.*

This model was contributed by the community contributor [HimariO](https://github.com/HimariO) ❤️ .

## Available Pipelines:

| Pipeline | Tasks | Demo |
|---|---|:---:|
| [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | - |

## Usage example

In the following we give a simple example of how to use a *T2IAdapter* checkpoint with Diffusers for inference.
All adapters use the same pipeline:

1. Images are first converted into the appropriate *control image* format.
2. The *control image* and *prompt* are passed to the [`StableDiffusionAdapterPipeline`].

Let's have a look at a simple example using the [Color Adapter](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1).

```python
from diffusers.utils import load_image

image = load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png")
```

![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png)

Then we can create our color palette by simply resizing it to 8 by 8 pixels and then scaling it back to the original size.

```python
from PIL import Image

color_palette = image.resize((8, 8))
color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST)
```

Let's take a look at the processed image.

![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_palette.png)

Next, create the adapter pipeline:

```py
import torch

from diffusers import StableDiffusionAdapterPipeline, T2IAdapter

# load the adapter in fp16 so its dtype matches the rest of the pipeline
adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1", torch_dtype=torch.float16)
pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    adapter=adapter,
    torch_dtype=torch.float16,
)
pipe.to("cuda")
```

Finally, pass the prompt and control image to the pipeline:

```py
# fix the random seed, so you will get the same result as the example
generator = torch.manual_seed(7)

out_image = pipe(
    "At night, glowing cubes in front of the beach",
    image=color_palette,
    generator=generator,
).images[0]
```

![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_output.png)
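
The strength of the adapter guidance can be tuned with the pipeline's `adapter_conditioning_scale` argument. The snippet below is a small sketch that reuses the pipeline above; the value `0.6` is only illustrative, and we assume the default scale is `1.0`.

```py
# Lower values let the text prompt dominate over the color palette (illustrative value).
out_image = pipe(
    "At night, glowing cubes in front of the beach",
    image=color_palette,
    generator=torch.manual_seed(7),
    adapter_conditioning_scale=0.6,
).images[0]
```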

## Available checkpoints

Non-diffusers checkpoints can be found under [TencentARC/T2I-Adapter](https://huggingface.co/TencentARC/T2I-Adapter/tree/main/models).

### T2I-Adapter with Stable Diffusion 1.4

| Model Name | Control Image Overview| Control Image Example | Generated Image Example |
|---|---|---|---|
|[TencentARC/t2iadapter_color_sd14v1](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1)<br/> *Trained with spatial color palette* | An image with an 8x8 color palette.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_sample_input.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_sample_output.png"/></a>|
|[TencentARC/t2iadapter_canny_sd14v1](https://huggingface.co/TencentARC/t2iadapter_canny_sd14v1)<br/> *Trained with canny edge detection* | A monochrome image with white edges on a black background.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/canny_sample_input.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/canny_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/canny_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/canny_sample_output.png"/></a>|
|[TencentARC/t2iadapter_sketch_sd14v1](https://huggingface.co/TencentARC/t2iadapter_sketch_sd14v1)<br/> *Trained with [PidiNet](https://github.com/zhuoinoulu/pidinet) edge detection* | A hand-drawn monochrome image with white outlines on a black background.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/sketch_sample_input.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/sketch_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/sketch_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/sketch_sample_output.png"/></a>|
|[TencentARC/t2iadapter_depth_sd14v1](https://huggingface.co/TencentARC/t2iadapter_depth_sd14v1)<br/> *Trained with MiDaS depth estimation* | A grayscale image with black representing deep areas and white representing shallow areas.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_output.png"/></a>|
|[TencentARC/t2iadapter_openpose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_openpose_sd14v1)<br/> *Trained with OpenPose bone image* | An [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/openpose_sample_input.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/openpose_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/openpose_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/openpose_sample_output.png"/></a>|
|[TencentARC/t2iadapter_keypose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_keypose_sd14v1)<br/> *Trained with mmpose skeleton image* | An [mmpose skeleton](https://github.com/open-mmlab/mmpose) image.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_output.png"/></a>|
|[TencentARC/t2iadapter_seg_sd14v1](https://huggingface.co/TencentARC/t2iadapter_seg_sd14v1)<br/>*Trained with semantic segmentation* | A [custom](https://github.com/TencentARC/T2I-Adapter/discussions/25) segmentation protocol image.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/seg_sample_input.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/seg_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/seg_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/seg_sample_output.png"/></a> |

### T2I-Adapter with Stable Diffusion 1.5

| Model Name | Control Image Overview| Control Image Example | Generated Image Example |
|---|---|---|---|
|[TencentARC/t2iadapter_canny_sd15v2](https://huggingface.co/TencentARC/t2iadapter_canny_sd15v2)| | | |
|[TencentARC/t2iadapter_depth_sd15v2](https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2)| | | |
|[TencentARC/t2iadapter_sketch_sd15v2](https://huggingface.co/TencentARC/t2iadapter_sketch_sd15v2)| | | |
|[TencentARC/t2iadapter_zoedepth_sd15v1](https://huggingface.co/TencentARC/t2iadapter_zoedepth_sd15v1)| | | |
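
As a hedged illustration of how one of the control images in the table can be produced, the sketch below builds a Canny edge map with OpenCV and feeds it to the canny adapter; the source photo, prompt, and thresholds are placeholders, and `opencv-python` is assumed to be installed.

```py
import cv2
import numpy as np
import torch
from PIL import Image

from diffusers import StableDiffusionAdapterPipeline, T2IAdapter
from diffusers.utils import load_image

# Any RGB photo works as a source; this dataset image is reused here as a placeholder.
source = load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png")

# Canny edge detection produces the monochrome, white-on-black control image the adapter expects.
gray = cv2.cvtColor(np.array(source), cv2.COLOR_RGB2GRAY)
edges = cv2.Canny(gray, 100, 200)
canny_image = Image.fromarray(edges)

adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_canny_sd14v1", torch_dtype=torch.float16)
pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", adapter=adapter, torch_dtype=torch.float16
).to("cuda")

image = pipe("a photo of a futuristic city at night", image=canny_image).images[0]
```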

## Combining multiple adapters

[`MultiAdapter`] can be used for applying multiple conditionings at once.

Here we use the keypose adapter for the character posture and the depth adapter for creating the scene.

```py
import torch
from PIL import Image

from diffusers.utils import load_image

cond_keypose = load_image(
    "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"
)
cond_depth = load_image(
    "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"
)
cond = [[cond_keypose, cond_depth]]

prompt = ["A man walking in an office room with a nice view"]
```

The two control images look as follows:

![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png)
![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png)

`MultiAdapter` combines the keypose and depth adapters.

`adapter_conditioning_scale` balances the relative influence of the different adapters.

```py
from diffusers import MultiAdapter, StableDiffusionAdapterPipeline, T2IAdapter

adapters = MultiAdapter(
    [
        T2IAdapter.from_pretrained("TencentARC/t2iadapter_keypose_sd14v1"),
        T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd14v1"),
    ]
)
adapters = adapters.to(torch.float16)

pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
    adapter=adapters,
)
pipe.to("cuda")

images = pipe(prompt, cond, adapter_conditioning_scale=[0.8, 0.8]).images
```

![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_depth_sample_output.png)

## T2I-Adapter vs ControlNet

T2I-Adapter is similar to [ControlNet](https://huggingface.co/docs/diffusers/main/en/api/pipelines/controlnet).
T2I-Adapter uses a smaller auxiliary network which is only run once for the entire diffusion process.
However, T2I-Adapter performs slightly worse than ControlNet.
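
To make the "run once" point concrete, here is a toy, self-contained sketch in plain PyTorch; `adapter` and `unet_step` are stand-ins for illustration only, not the real T2I-Adapter or Stable Diffusion UNet.

```py
import torch

# Toy stand-ins, just to make the control flow concrete.
adapter = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1)  # "adapter": control image -> feature residual

def unet_step(latents, residual):
    # "UNet": one denoising update that consumes the cached residual.
    return latents - 0.1 * residual

control_image = torch.randn(1, 3, 64, 64)
latents = torch.randn(1, 4, 64, 64)

# The adapter depends only on the control image, so it runs once, before the loop.
adapter_residual = adapter(control_image)

for _ in range(50):  # denoising loop
    # The cached residual is reused at every step; the adapter is never run again.
    latents = unet_step(latents, adapter_residual)

# A ControlNet additionally takes the current latents and timestep as inputs, so the
# corresponding (much larger) network would have to be evaluated inside the loop.
```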

## StableDiffusionAdapterPipeline
[[autodoc]] StableDiffusionAdapterPipeline
	- all
	- __call__
	- enable_attention_slicing
	- disable_attention_slicing
	- enable_vae_slicing
	- disable_vae_slicing
	- enable_xformers_memory_efficient_attention
	- disable_xformers_memory_efficient_attention
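
The memory helpers listed above can be enabled on an instantiated pipeline. A brief sketch using the `pipe` from the usage example; whether they help depends on your GPU, and the last call assumes the xformers package is installed.

```py
# Optional memory savings for the pipeline created in the usage example above.
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()

# Only if the xformers package is installed.
pipe.enable_xformers_memory_efficient_attention()
```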

docs/source/en/index.mdx

+1

```diff
@@ -69,6 +69,7 @@ The library has three main components:
 | [score_sde_ve](./api/pipelines/score_sde_ve) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
 | [score_sde_vp](./api/pipelines/score_sde_vp) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
 | [semantic_stable_diffusion](./api/pipelines/semantic_stable_diffusion) | [Semantic Guidance](https://arxiv.org/abs/2301.12247) | Text-Guided Generation |
+| [stable_diffusion_adapter](./api/pipelines/stable_diffusion/adapter) | [**T2I-Adapter**](https://arxiv.org/abs/2302.08453) | Image-to-Image Text-Guided Generation | -
 | [stable_diffusion_text2img](./api/pipelines/stable_diffusion/text2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation |
 | [stable_diffusion_img2img](./api/pipelines/stable_diffusion/img2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation |
 | [stable_diffusion_inpaint](./api/pipelines/stable_diffusion/inpaint) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting |
```

docs/source/en/using-diffusers/controlling_generation.mdx

+11 −1

```diff
@@ -59,6 +59,7 @@ For convenience, we provide a table to denote which methods are inference-only a
 | [Custom Diffusion](#custom-diffusion) | | | |
 | [Model Editing](#model-editing) | | | |
 | [DiffEdit](#diffedit) | | | |
+| [T2I-Adapter](#t2i-adapter) | | | |
 
 ## Instruct Pix2Pix
 
@@ -215,4 +216,13 @@ To know more details, check out the [official doc](../api/pipelines/stable_diffu
 [DiffEdit](../api/pipelines/stable_diffusion/diffedit) allows for semantic editing of input images along with
 input prompts while preserving the original input images as much as possible.
 
-To know more details, check out the [official doc](../api/pipelines/stable_diffusion/model_editing).
+To know more details, check out the [official doc](../api/pipelines/stable_diffusion/model_editing).
+## T2I-Adapter
+
+[Paper](https://arxiv.org/abs/2302.08453)
+
+[T2I-Adapter](../api/pipelines/stable_diffusion/adapter) is an auxiliary network which adds an extra condition.
+There are 8 canonical pre-trained adapters trained on different conditionings such as edge detection, sketch,
+depth maps, and semantic segmentations.
+
+See [here](../api/pipelines/stable_diffusion/adapter) for more information on how to use it.
```
