
Commit 21e39a5

takuma104, williamberman, pcuenca, dqueue, and sayakpaul authored; Jimmy committed
Add a ControlNet model & pipeline (huggingface#2407)
* add scaffold - copied convert_controlnet_to_diffusers.py from convert_original_stable_diffusion_to_diffusers.py
* Add support to load ControlNet (WIP) - this makes Missing Key error on ControlNetModel
* Update to convert ControlNet without error msg - init impl for StableDiffusionControlNetPipeline - init impl for ControlNetModel
* cleanup of commented out
* split create_controlnet_diffusers_config() from create_unet_diffusers_config() - add config: hint_channels
* Add input_hint_block, input_zero_conv and middle_block_out - this makes missing key error on loading model
* add unet_2d_blocks_controlnet.py - copied from unet_2d_blocks.py as impl CrossAttnDownBlock2D,DownBlock2D - this makes missing key error on loading model
* Add loading for input_hint_block, zero_convs and middle_block_out - this makes no error message on model loading
* Copy from UNet2DConditionalModel except __init__
* Add ultra primitive test for ControlNetModel inference
* Support ControlNetModel inference - without exceptions
* copy forward() from UNet2DConditionModel
* Impl ControlledUNet2DConditionModel inference - test_controlled_unet_inference passed
* Frozen weight & biases for training
* Minimized version of ControlNet/ControlledUnet - test_modules_controllnet.py passed
* make style
* Add support model loading for minimized ver
* Remove all previous version files
* from_pretrained and inference test passed
* copied from pipeline_stable_diffusion.py except `__init__()`
* Impl pipeline, pixel match test (almost) passed.
* make style
* make fix-copies
* Fix to add import ControlNet blocks for `make fix-copies`
* Remove einops dependency
* Support np.ndarray, PIL.Image for controlnet_hint
* set default config file as lllyasviel's
* Add support grayscale (hw) numpy array
* Add and update docstrings
* add control_net.mdx
* add control_net.mdx to toctree
* Update copyright year
* Fix to add PIL.Image RGB->BGR conversion - thanks @Mystfit
* make fix-copies
* add basic fast test for controlnet
* add slow test for controlnet/unet
* Ignore down/up_block len check on ControlNet
* add a copy from test_stable_diffusion.py
* Accept controlnet_hint is None
* merge pipeline_stable_diffusion.py diff
* Update class name to SDControlNetPipeline
* make style
* Baseline fast test almost passed (w long desc)
* still needs investigation. The following didn't pass, as described in the TODO comment:
  - test_stable_diffusion_long_prompt
  - test_stable_diffusion_no_safety_checker
  The following didn't pass, same as stable_diffusion_pipeline:
  - test_attention_slicing_forward_pass
  - test_inference_batch_single_identical
  - test_xformers_attention_forwardGenerator_pass
  These seem to come from calculation accuracy.
* Add note comment related to vae_scale_factor
* add test_stable_diffusion_controlnet_ddim
* add assertion for vae_scale_factor != 8
* slow test of pipeline almost passed. Failed: test_stable_diffusion_pipeline_with_model_offloading - ImportError: `enable_model_offload` requires `accelerate v0.17.0` or higher, but the current latest version is 0.16.0
* test_stable_diffusion_long_prompt passed
* test_stable_diffusion_no_safety_checker passed - due to its model size, move to slow test
* remove PoC test files
* fix num_of_image, prompt length issue and add test
* add support List[PIL.Image] for controlnet_hint
* wip
* all slow tests passed
* make style
* update for slow test
* RGB(PIL)->BGR(ctrlnet) conversion
* fixes
* remove manual num_images_per_prompt test
* add document
* add `image` argument docstring
* make style
* Add line to correct conversion
* add controlnet_conditioning_scale (aka control_scales strength)
* rgb channel ordering by default
* image batching logic
* Add control image descriptions for each checkpoint
* Only save controlnet model in conversion script
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py typo Co-authored-by: Pedro Cuenca <[email protected]>
* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx Co-authored-by: Pedro Cuenca <[email protected]>
* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx Co-authored-by: Pedro Cuenca <[email protected]>
* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx Co-authored-by: Pedro Cuenca <[email protected]>
* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx Co-authored-by: Pedro Cuenca <[email protected]>
* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx Co-authored-by: Pedro Cuenca <[email protected]>
* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx Co-authored-by: Pedro Cuenca <[email protected]>
* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx Co-authored-by: Pedro Cuenca <[email protected]>
* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx Co-authored-by: Pedro Cuenca <[email protected]>
* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx Co-authored-by: Pedro Cuenca <[email protected]>
* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx Co-authored-by: Pedro Cuenca <[email protected]>
* add generated image example
* a depth mask -> a depth map
* rename control_net.mdx to controlnet.mdx
* fix toc title
* add ControlNet abstract and link
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py Co-authored-by: dqueue <[email protected]>
* remove controlnet constructor arguments re: @patrickvonplaten
* [integration tests] test canny
* test_canny fixes
* [integration tests] test_depth
* [integration tests] test_hed
* [integration tests] test_mlsd
* add channel order config to controlnet
* [integration tests] test normal
* [integration tests] test_openpose test_scribble
* change height and width to default to conditioning image
* [integration tests] test seg
* style
* test_depth fix
* [integration tests] size fixes
* [integration tests] cpu offloading
* style
* generalize controlnet embedding
* fix conversion script
* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx Co-authored-by: Sayak Paul <[email protected]>
* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx Co-authored-by: Sayak Paul <[email protected]>
* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx Co-authored-by: Sayak Paul <[email protected]>
* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx Co-authored-by: Sayak Paul <[email protected]>
* Style adapted to the documentation of pix2pix
* merge main by hand
* style
* [docs] controlling generation doc nits
* correct some things
* add: controlnetmodel to autodoc.
* finish docs
* finish
* finish 2
* correct images
* finish controlnet
* Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]>
* uP
* upload model
* up
* up

---------

Co-authored-by: William Berman <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: dqueue <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 0b14107 commit 21e39a5

20 files changed, +2100 -27 lines
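Beyond the basic pipeline, the commit log above mentions a `controlnet_conditioning_scale` argument (a control-strength multiplier) and support for `List[PIL.Image]` control hints. The following is a hedged sketch of how those options would be passed, assuming the pipeline's final call signature matches the log; the prompts, image URL, and scale value are illustrative only:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Illustrative control image (the canny edge map shown later in the docs diff).
canny_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_canny_edged.png"
)

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# controlnet_conditioning_scale scales the ControlNet residuals before they are
# added to the UNet; lower values weaken the conditioning.
images = pipe(
    prompt=["a dancer", "an astronaut"],   # one prompt per output image
    image=[canny_image, canny_image],      # List[PIL.Image] control hints
    controlnet_conditioning_scale=0.8,
    num_inference_steps=20,
).images
```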

docs/source/en/_toctree.yml (+2)

@@ -165,6 +165,8 @@
         title: Self-Attention Guidance
       - local: api/pipelines/stable_diffusion/panorama
         title: MultiDiffusion Panorama
+      - local: api/pipelines/stable_diffusion/controlnet
+        title: Text-to-Image Generation with ControlNet Conditioning
       title: Stable Diffusion
     - local: api/pipelines/stable_diffusion_2
       title: Stable Diffusion 2

docs/source/en/api/models.mdx (+6)

@@ -64,6 +64,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
 ## PriorTransformerOutput
 [[autodoc]] models.prior_transformer.PriorTransformerOutput

+## ControlNetOutput
+[[autodoc]] models.controlnet.ControlNetOutput
+
+## ControlNetModel
+[[autodoc]] ControlNetModel
+
 ## FlaxModelMixin
 [[autodoc]] FlaxModelMixin
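The diff above registers the new `ControlNetModel` and its `ControlNetOutput` in the API docs. As a rough, illustrative sketch (not part of this commit) of how the model's outputs flow into the UNet during denoising, assuming the argument names `controlnet_cond`, `down_block_additional_residuals`, and `mid_block_additional_residual` from the diffusers API rather than anything shown in this diff:

```python
import torch
from diffusers import ControlNetModel, UNet2DConditionModel

# The ControlNet takes noisy latents, a timestep, text embeddings, and a control
# image, and returns residuals that the UNet adds to its skip and mid blocks.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

latents = torch.randn(1, 4, 64, 64)       # noisy latents (64 = 512 / vae_scale_factor 8)
timestep = torch.tensor([10])             # diffusion timestep
text_emb = torch.randn(1, 77, 768)        # CLIP text embeddings for SD 1.5
control = torch.randn(1, 3, 512, 512)     # stand-in for a preprocessed control image

down_res, mid_res = controlnet(
    latents, timestep, encoder_hidden_states=text_emb,
    controlnet_cond=control, return_dict=False,
)

# The frozen UNet consumes the ControlNet residuals as additional inputs.
noise_pred = unet(
    latents, timestep, encoder_hidden_states=text_emb,
    down_block_additional_residuals=down_res,
    mid_block_additional_residual=mid_res,
).sample
```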

docs/source/en/api/pipelines/overview.mdx (+1)

@@ -46,6 +46,7 @@ available a colab notebook to directly try them out.
 |---|---|:---:|:---:|
 | [alt_diffusion](./alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
 | [audio_diffusion](./audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio_diffusion.git) | Unconditional Audio Generation |
+| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1AiR7Q-sBqO88NCyswpfiuwXZc7DfMyKA?usp=sharing)
 | [cycle_diffusion](./cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
 | [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
 | [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx (new file, +160)

@@ -0,0 +1,160 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Text-to-Image Generation with ControlNet Conditioning

## Overview

[Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) by Lvmin Zhang and Maneesh Agrawala.

Using the pretrained models we can provide control images (for example, a depth map) to control Stable Diffusion text-to-image generation so that it follows the structure of the depth image and fills in the details.

The abstract of the paper is the following:

*We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.*

This model was contributed by the amazing community contributor [takuma104](https://huggingface.co/takuma104) ❤️.

Resources:

* [Paper](https://arxiv.org/abs/2302.05543)
* [Original Code](https://github.com/lllyasviel/ControlNet)

## Available Pipelines:

| Pipeline | Tasks | Demo |
|---|---|:---:|
| [StableDiffusionControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py) | *Text-to-Image Generation with ControlNet Conditioning* | [Colab Example](https://colab.research.google.com/drive/1AiR7Q-sBqO88NCyswpfiuwXZc7DfMyKA?usp=sharing) |
## Usage example

In the following we give a simple example of how to use a *ControlNet* checkpoint with Diffusers for inference.
The inference process is the same for every ControlNet checkpoint:

1. Take an image and run it through a pre-conditioning processor.
2. Run the pre-processed image through the [`StableDiffusionControlNetPipeline`].

Let's have a look at a simple example using the [Canny Edge ControlNet](https://huggingface.co/lllyasviel/sd-controlnet-canny).

```python
from diffusers import StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Let's load the popular vermeer image
image = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
```

![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png)

Next, we process the image to get the canny image. This is step *1.*: running the pre-conditioning processor. The pre-conditioning processor is different for every ControlNet. Please see the model cards of the [official checkpoints](#controlnet-with-stable-diffusion-1.5) for more information about other models.

First, we need to install opencv:

```
pip install opencv-contrib-python
```

Then we can retrieve the canny edges of the image.

```python
import cv2
from PIL import Image
import numpy as np

image = np.array(image)

low_threshold = 100
high_threshold = 200

image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
```

Let's take a look at the processed image.

![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_canny_edged.png)

Now, we load the official [Stable Diffusion 1.5 Model](https://huggingface.co/runwayml/stable-diffusion-v1-5) as well as the ControlNet for canny edges.

```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import torch

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
```

To speed things up and reduce memory, let's enable model offloading and use the fast [`UniPCMultistepScheduler`].

```py
from diffusers import UniPCMultistepScheduler

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# this command loads the individual model components on GPU on-demand.
pipe.enable_model_cpu_offload()
```

Finally, we can run the pipeline:

```py
generator = torch.manual_seed(0)

out_image = pipe(
    "disco dancer with colorful lights", num_inference_steps=20, generator=generator, image=canny_image
).images[0]
```

This should take only around 3-4 seconds on GPU (depending on hardware). The output image then looks as follows:

![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_disco_dancing.png)

**Note**: To see how to run all other ControlNet checkpoints, please have a look at [ControlNet with Stable Diffusion 1.5](#controlnet-with-stable-diffusion-1.5).

<!-- TODO: add space -->
## Available checkpoints

ControlNet requires a *control image* in addition to the text-to-image *prompt*.
Each pretrained variant is trained with a different conditioning method and therefore expects a different kind of control image. For example, Canny edge conditioning requires the control image to be the output of a Canny filter, while depth conditioning requires the control image to be a depth map. See the overview and image examples below to know more.

All checkpoints can be found under the authors' namespace [lllyasviel](https://huggingface.co/lllyasviel).

### ControlNet with Stable Diffusion 1.5

| Model Name | Control Image Overview | Control Image Example | Generated Image Example |
|---|---|---|---|
|[lllyasviel/sd-controlnet-canny](https://huggingface.co/lllyasviel/sd-controlnet-canny)<br/> *Trained with canny edge detection* | A monochrome image with white edges on a black background.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_bird_canny.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_bird_canny.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_bird_canny_1.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_bird_canny_1.png"/></a>|
|[lllyasviel/sd-controlnet-depth](https://huggingface.co/lllyasviel/sd-controlnet-depth)<br/> *Trained with Midas depth estimation* |A grayscale image with black representing deep areas and white representing shallow areas.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_vermeer_depth.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_vermeer_depth.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_vermeer_depth_2.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_vermeer_depth_2.png"/></a>|
|[lllyasviel/sd-controlnet-hed](https://huggingface.co/lllyasviel/sd-controlnet-hed)<br/> *Trained with HED edge detection (soft edge)* |A monochrome image with white soft edges on a black background.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_bird_hed.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_bird_hed.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_bird_hed_1.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_bird_hed_1.png"/></a> |
|[lllyasviel/sd-controlnet-mlsd](https://huggingface.co/lllyasviel/sd-controlnet-mlsd)<br/> *Trained with M-LSD line detection* |A monochrome image composed only of white straight lines on a black background.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_room_mlsd.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_room_mlsd.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_room_mlsd_0.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_room_mlsd_0.png"/></a>|
|[lllyasviel/sd-controlnet-normal](https://huggingface.co/lllyasviel/sd-controlnet-normal)<br/> *Trained with normal map* |A [normal mapped](https://en.wikipedia.org/wiki/Normal_mapping) image.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_human_normal.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_human_normal.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_human_normal_1.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_human_normal_1.png"/></a>|
|[lllyasviel/sd-controlnet_openpose](https://huggingface.co/lllyasviel/sd-controlnet_openpose)<br/> *Trained with OpenPose bone image* |An [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_human_openpose.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_human_openpose.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_human_openpose_0.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_human_openpose_0.png"/></a>|
|[lllyasviel/sd-controlnet_scribble](https://huggingface.co/lllyasviel/sd-controlnet_scribble)<br/> *Trained with human scribbles* |A hand-drawn monochrome image with white outlines on a black background.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_vermeer_scribble.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_vermeer_scribble.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_vermeer_scribble_0.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_vermeer_scribble_0.png"/></a> |
|[lllyasviel/sd-controlnet_seg](https://huggingface.co/lllyasviel/sd-controlnet_seg)<br/>*Trained with semantic segmentation* |An image that follows [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/)'s segmentation protocol.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_room_seg.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_room_seg.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_room_seg_1.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_room_seg_1.png"/></a> |
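Each checkpoint in the table expects its own kind of control image, but the loading code is otherwise identical. As a hedged sketch that is not part of the committed docs, swapping in the depth checkpoint listed above would look roughly like this, using the depth control image linked in the table; the prompt is illustrative:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
from diffusers.utils import load_image

# Depth conditioning needs a depth map instead of canny edges.
depth_image = load_image(
    "https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_vermeer_depth.png"
)

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

image = pipe(
    "a portrait of a woman with a pearl earring", num_inference_steps=20, image=depth_image
).images[0]
```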
[[autodoc]] StableDiffusionControlNetPipeline
    - all
    - __call__
    - enable_attention_slicing
    - disable_attention_slicing
    - enable_vae_slicing
    - disable_vae_slicing
    - enable_xformers_memory_efficient_attention
    - disable_xformers_memory_efficient_attention
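The autodoc directive above lists the pipeline's memory-optimization helpers. A brief, illustrative sketch of toggling them on a loaded pipeline follows (xformers is assumed to be installed; this snippet is not part of the commit):

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)

# Trade some speed for lower peak memory by computing attention in slices,
pipe.enable_attention_slicing()
# and decode latents one image at a time.
pipe.enable_vae_slicing()
# If xformers is available, switch to its memory-efficient attention kernels.
pipe.enable_xformers_memory_efficient_attention()

# The matching disable_* methods restore the default behavior.
pipe.disable_attention_slicing()
```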

docs/source/en/index.mdx (+1)

@@ -36,6 +36,7 @@ available a colab notebook to directly try them out.
 |---|---|:---:|:---:|
 | [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
 | [audio_diffusion](./api/pipelines/audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/audio_diffusion_pipeline.ipynb)
+| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1AiR7Q-sBqO88NCyswpfiuwXZc7DfMyKA?usp=sharing)
 | [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
 | [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
 | [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |

docs/source/en/using-diffusers/controlling_generation.mdx (+12)

@@ -35,6 +35,7 @@ Unless otherwise mentioned, these are techniques that work with existing models
 7. [MultiDiffusion Panorama](#multidiffusion-panorama)
 8. [DreamBooth](#dreambooth)
 9. [Textual Inversion](#textual-inversion)
+10. [ControlNet](#controlnet)

 ## Instruct Pix2Pix

@@ -146,3 +147,14 @@ See [here](../training/dreambooth) for more information on how to use it.
 [Textual Inversion](../training/text_inversion) fine-tunes a model to teach it about a new concept. I.e. a few pictures of a style of artwork can be used to generate images in that style.

 See [here](../training/text_inversion) for more information on how to use it.
+
+## ControlNet
+
+[Paper](https://arxiv.org/abs/2302.05543)
+
+[ControlNet](../api/pipelines/stable_diffusion/controlnet) is an auxiliary network which adds an extra condition.
+There are 8 canonical pre-trained ControlNets trained on different conditionings such as edge detection, scribbles,
+depth maps, and semantic segmentations.
+
+See [here](../api/pipelines/stable_diffusion/controlnet) for more information on how to use it.

scripts/convert_original_stable_diffusion_to_diffusers.py (+10, -1)

@@ -120,6 +120,9 @@
         help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.",
         required=False,
     )
+    parser.add_argument(
+        "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint."
+    )
     args = parser.parse_args()

     pipe = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -137,5 +140,11 @@
         stable_unclip=args.stable_unclip,
         stable_unclip_prior=args.stable_unclip_prior,
         clip_stats_path=args.clip_stats_path,
+        controlnet=args.controlnet,
     )
-    pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
+
+    if args.controlnet:
+        # only save the controlnet model
+        pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
+    else:
+        pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
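With the new `--controlnet` flag, the script saves only the converted `ControlNetModel` to `--dump_path`. A hedged sketch of how such a locally converted checkpoint would then be used (the directory name is a placeholder, not something produced by this commit):

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# "./converted-controlnet" stands in for whatever --dump_path was passed to the script.
controlnet = ControlNetModel.from_pretrained("./converted-controlnet", torch_dtype=torch.float16)

# The converted ControlNet is then paired with a regular Stable Diffusion base model.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
```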

src/diffusers/__init__.py (+2)

@@ -34,6 +34,7 @@
 else:
     from .models import (
         AutoencoderKL,
+        ControlNetModel,
         ModelMixin,
         PriorTransformer,
         Transformer2DModel,
@@ -113,6 +114,7 @@
         PaintByExamplePipeline,
         SemanticStableDiffusionPipeline,
         StableDiffusionAttendAndExcitePipeline,
+        StableDiffusionControlNetPipeline,
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionImageVariationPipeline,
         StableDiffusionImg2ImgPipeline,

src/diffusers/models/__init__.py (+1)

@@ -17,6 +17,7 @@

 if is_torch_available():
     from .autoencoder_kl import AutoencoderKL
+    from .controlnet import ControlNetModel
     from .dual_transformer_2d import DualTransformer2DModel
     from .modeling_utils import ModelMixin
     from .prior_transformer import PriorTransformer
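Together, the two `__init__.py` additions re-export the new class and pipeline at the package root, so user code can import them directly; a quick illustrative check:

```python
# Both names are exposed at the top level by this commit.
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

print(ControlNetModel.__name__)                    # "ControlNetModel"
print(StableDiffusionControlNetPipeline.__name__)  # "StableDiffusionControlNetPipeline"
```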
