diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index f695dd3c1df9..b60933c05228 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -196,12 +196,12 @@
title: DDIM
- local: api/pipelines/ddpm
title: DDPM
+ - local: api/pipelines/deepfloyd_if
+ title: DeepFloyd IF
- local: api/pipelines/diffedit
title: DiffEdit
- local: api/pipelines/dit
title: DiT
- - local: api/pipelines/deepfloyd_if
- title: DeepFloyd IF
- local: api/pipelines/pix2pix
title: InstructPix2Pix
- local: api/pipelines/kandinsky
@@ -255,6 +255,8 @@
title: Super-Resolution
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
title: LDM3D Text-to-(RGB, Depth)
+ - local: api/pipelines/stable_diffusion/adapter
+ title: Stable Diffusion T2I-adapter
title: Stable Diffusion
- local: api/pipelines/stable_unclip
title: Stable unCLIP
diff --git a/docs/source/en/api/pipelines/overview.mdx b/docs/source/en/api/pipelines/overview.mdx
index 693c32565c46..1d61ae6a1314 100644
--- a/docs/source/en/api/pipelines/overview.mdx
+++ b/docs/source/en/api/pipelines/overview.mdx
@@ -66,6 +66,7 @@ available a colab notebook to directly try them out.
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [semantic_stable_diffusion](./semantic_stable_diffusion) | [**SEGA: Instructing Diffusion using Semantic Dimensions**](https://arxiv.org/abs/2301.12247) | Text-to-Image Generation |
+| [stable_diffusion_adapter](./stable_diffusion/adapter) | [**T2I-Adapter**](https://arxiv.org/abs/2302.08453) | Image-to-Image Text-Guided Generation with Adapters | -
| [stable_diffusion_text2img](./stable_diffusion/text2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion_img2img](./stable_diffusion/img2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion_inpaint](./stable_diffusion/inpaint) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
diff --git a/docs/source/en/api/pipelines/stable_diffusion/adapter.mdx b/docs/source/en/api/pipelines/stable_diffusion/adapter.mdx
new file mode 100644
index 000000000000..19351e1713b6
--- /dev/null
+++ b/docs/source/en/api/pipelines/stable_diffusion/adapter.mdx
@@ -0,0 +1,187 @@
+
+
+# Text-to-Image Generation with Adapter Conditioning
+
+## Overview
+
+[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.08453) by Chong Mou, Xintao Wang, Liangbin Xie, Jian Zhang, Zhongang Qi, Ying Shan, Xiaohu Qie.
+
+Using the pretrained models we can provide control images (for example, a depth map) to control Stable Diffusion text-to-image generation so that it follows the structure of the depth image and fills in the details.
+
+The abstract of the paper is the following:
+
+*The incredible generative ability of large-scale text-to-image (T2I) models has demonstrated strong power of learning complex structures and meaningful semantics. However, relying solely on text prompts cannot fully take advantage of the knowledge learned by the model, especially when flexible and accurate structure control is needed. In this paper, we aim to ``dig out" the capabilities that T2I models have implicitly learned, and then explicitly use them to control the generation more granularly. Specifically, we propose to learn simple and small T2I-Adapters to align internal knowledge in T2I models with external control signals, while freezing the original large T2I models. In this way, we can train various adapters according to different conditions, and achieve rich control and editing effects. Further, the proposed T2I-Adapters have attractive properties of practical value, such as composability and generalization ability. Extensive experiments demonstrate that our T2I-Adapter has promising generation quality and a wide range of applications.*
+
+This model was contributed by the community contributor [HimariO](https://github.com/HimariO) ❤️ .
+
+## Available Pipelines:
+
+| Pipeline | Tasks | Demo
+|---|---|:---:|
+| [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | -
+
+## Usage example
+
+In the following we give a simple example of how to use a *T2IAdapter* checkpoint with Diffusers for inference.
+All adapters use the same pipeline.
+
+ 1. Images are first converted into the appropriate *control image* format.
+ 2. The *control image* and *prompt* are passed to the [`StableDiffusionAdapterPipeline`].
+
+Let's have a look at a simple example using the [Color Adapter](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1).
+
+```python
+from diffusers.utils import load_image
+
+image = load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png")
+```
+
+
+
+
+Then we can create our color palette by simply resizing it to 8 by 8 pixels and then scaling it back to original size.
+
+```python
+from PIL import Image
+
+color_palette = image.resize((8, 8))
+color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST)
+```
+
+Let's take a look at the processed image.
+
+
+
+
+Next, create the adapter pipeline
+
+```py
+import torch
+from diffusers import StableDiffusionAdapterPipeline, T2IAdapter
+
+adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1")
+pipe = StableDiffusionAdapterPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ adapter=adapter,
+ torch_dtype=torch.float16,
+)
+pipe.to("cuda")
+```
+
+Finally, pass the prompt and control image to the pipeline
+
+```py
+# fix the random seed, so you will get the same result as the example
+generator = torch.manual_seed(7)
+
+out_image = pipe(
+ "At night, glowing cubes in front of the beach",
+ image=color_palette,
+ generator=generator,
+).images[0]
+```
+
+
+
+
+## Available checkpoints
+
+Non-diffusers checkpoints can be found under [TencentARC/T2I-Adapter](https://huggingface.co/TencentARC/T2I-Adapter/tree/main/models).
+
+### T2I-Adapter with Stable Diffusion 1.4
+
+| Model Name | Control Image Overview| Control Image Example | Generated Image Example |
+|---|---|---|---|
+|[TencentARC/t2iadapter_color_sd14v1](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1)
*Trained with spatial color palette* | A image with 8x8 color palette.|
|
|
+|[TencentARC/t2iadapter_canny_sd14v1](https://huggingface.co/TencentARC/t2iadapter_canny_sd14v1)
*Trained with canny edge detection* | A monochrome image with white edges on a black background.|
|
|
+|[TencentARC/t2iadapter_sketch_sd14v1](https://huggingface.co/TencentARC/t2iadapter_sketch_sd14v1)
*Trained with [PidiNet](https://github.com/zhuoinoulu/pidinet) edge detection* | A hand-drawn monochrome image with white outlines on a black background.|
|
|
+|[TencentARC/t2iadapter_depth_sd14v1](https://huggingface.co/TencentARC/t2iadapter_depth_sd14v1)
*Trained with Midas depth estimation* | A grayscale image with black representing deep areas and white representing shallow areas.|
|
|
+|[TencentARC/t2iadapter_openpose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_openpose_sd14v1)
*Trained with OpenPose bone image* | A [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image.|
|
|
+|[TencentARC/t2iadapter_keypose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_keypose_sd14v1)
*Trained with mmpose skeleton image* | A [mmpose skeleton](https://github.com/open-mmlab/mmpose) image.|
|
|
+|[TencentARC/t2iadapter_seg_sd14v1](https://huggingface.co/TencentARC/t2iadapter_seg_sd14v1)
*Trained with semantic segmentation* | An [custom](https://github.com/TencentARC/T2I-Adapter/discussions/25) segmentation protocol image.|
|
|
+|[TencentARC/t2iadapter_canny_sd15v2](https://huggingface.co/TencentARC/t2iadapter_canny_sd15v2)||
+|[TencentARC/t2iadapter_depth_sd15v2](https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2)||
+|[TencentARC/t2iadapter_sketch_sd15v2](https://huggingface.co/TencentARC/t2iadapter_sketch_sd15v2)||
+|[TencentARC/t2iadapter_zoedepth_sd15v1](https://huggingface.co/TencentARC/t2iadapter_zoedepth_sd15v1)||
+
+## Combining multiple adapters
+
+[`MultiAdapter`] can be used for applying multiple conditionings at once.
+
+Here we use the keypose adapter for the character posture and the depth adapter for creating the scene.
+
+```py
+import torch
+from PIL import Image
+from diffusers.utils import load_image
+
+cond_keypose = load_image(
+ "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"
+)
+cond_depth = load_image(
+ "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"
+)
+cond = [[cond_keypose, cond_depth]]
+
+prompt = ["A man walking in an office room with a nice view"]
+```
+
+The two control images look as such:
+
+
+
+
+
+`MultiAdapter` combines keypose and depth adapters.
+
+`adapter_conditioning_scale` balances the relative influence of the different adapters.
+
+```py
+from diffusers import StableDiffusionAdapterPipeline, MultiAdapter
+
+adapters = MultiAdapter(
+ [
+ T2IAdapter.from_pretrained("TencentARC/t2iadapter_keypose_sd14v1"),
+ T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd14v1"),
+ ]
+)
+adapters = adapters.to(torch.float16)
+
+pipe = StableDiffusionAdapterPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ torch_dtype=torch.float16,
+ adapter=adapters,
+)
+
+images = pipe(prompt, cond, adapter_conditioning_scale=[0.8, 0.8])
+```
+
+
+
+
+## T2I Adapter vs ControlNet
+
+T2I-Adapter is similar to [ControlNet](https://huggingface.co/docs/diffusers/main/en/api/pipelines/controlnet).
+T2i-Adapter uses a smaller auxiliary network which is only run once for the entire diffusion process.
+However, T2I-Adapter performs slightly worse than ControlNet.
+
+## StableDiffusionAdapterPipeline
+[[autodoc]] StableDiffusionAdapterPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx
index 1bb52e628a04..f2012abc6970 100644
--- a/docs/source/en/index.mdx
+++ b/docs/source/en/index.mdx
@@ -69,6 +69,7 @@ The library has three main components:
| [score_sde_ve](./api/pipelines/score_sde_ve) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./api/pipelines/score_sde_vp) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [semantic_stable_diffusion](./api/pipelines/semantic_stable_diffusion) | [Semantic Guidance](https://arxiv.org/abs/2301.12247) | Text-Guided Generation |
+| [stable_diffusion_adapter](./api/pipelines/stable_diffusion/adapter) | [**T2I-Adapter**](https://arxiv.org/abs/2302.08453) | Image-to-Image Text-Guided Generation | -
| [stable_diffusion_text2img](./api/pipelines/stable_diffusion/text2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation |
| [stable_diffusion_img2img](./api/pipelines/stable_diffusion/img2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation |
| [stable_diffusion_inpaint](./api/pipelines/stable_diffusion/inpaint) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting |
diff --git a/docs/source/en/using-diffusers/controlling_generation.mdx b/docs/source/en/using-diffusers/controlling_generation.mdx
index 57b5640ffcd5..2660903517eb 100644
--- a/docs/source/en/using-diffusers/controlling_generation.mdx
+++ b/docs/source/en/using-diffusers/controlling_generation.mdx
@@ -59,6 +59,7 @@ For convenience, we provide a table to denote which methods are inference-only a
| [Custom Diffusion](#custom-diffusion) | ❌ | ✅ | |
| [Model Editing](#model-editing) | ✅ | ❌ | |
| [DiffEdit](#diffedit) | ✅ | ❌ | |
+| [T2I-Adapter](#t2i-adapter) | ✅ | ❌ | |
## Instruct Pix2Pix
@@ -215,4 +216,13 @@ To know more details, check out the [official doc](../api/pipelines/stable_diffu
[DiffEdit](../api/pipelines/stable_diffusion/diffedit) allows for semantic editing of input images along with
input prompts while preserving the original input images as much as possible.
-To know more details, check out the [official doc](../api/pipelines/stable_diffusion/model_editing).
\ No newline at end of file
+To know more details, check out the [official doc](../api/pipelines/stable_diffusion/model_editing).
+## T2I-Adapter
+
+[Paper](https://arxiv.org/abs/2302.08453)
+
+[T2I-Adapter](../api/pipelines/stable_diffusion/adapter) is an auxiliary network which adds an extra condition.
+There are 8 canonical pre-trained adapters trained on different conditionings such as edge detection, sketch,
+depth maps, and semantic segmentations.
+
+See [here](../api/pipelines/stable_diffusion/adapter) for more information on how to use it.
diff --git a/scripts/convert_original_t2i_adapter.py b/scripts/convert_original_t2i_adapter.py
new file mode 100644
index 000000000000..01a1fecf4e4b
--- /dev/null
+++ b/scripts/convert_original_t2i_adapter.py
@@ -0,0 +1,250 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Conversion script for the T2I-Adapter checkpoints.
+"""
+
+import argparse
+
+import torch
+
+from diffusers import T2IAdapter
+
+
+def convert_adapter(src_state, in_channels):
+ original_body_length = max([int(x.split(".")[1]) for x in src_state.keys() if "body." in x]) + 1
+
+ assert original_body_length == 8
+
+ # (0, 1) -> channels 1
+ assert src_state["body.0.block1.weight"].shape == (320, 320, 3, 3)
+
+ # (2, 3) -> channels 2
+ assert src_state["body.2.in_conv.weight"].shape == (640, 320, 1, 1)
+
+ # (4, 5) -> channels 3
+ assert src_state["body.4.in_conv.weight"].shape == (1280, 640, 1, 1)
+
+ # (6, 7) -> channels 4
+ assert src_state["body.6.block1.weight"].shape == (1280, 1280, 3, 3)
+
+ res_state = {
+ "adapter.conv_in.weight": src_state.pop("conv_in.weight"),
+ "adapter.conv_in.bias": src_state.pop("conv_in.bias"),
+ # 0.resnets.0
+ "adapter.body.0.resnets.0.block1.weight": src_state.pop("body.0.block1.weight"),
+ "adapter.body.0.resnets.0.block1.bias": src_state.pop("body.0.block1.bias"),
+ "adapter.body.0.resnets.0.block2.weight": src_state.pop("body.0.block2.weight"),
+ "adapter.body.0.resnets.0.block2.bias": src_state.pop("body.0.block2.bias"),
+ # 0.resnets.1
+ "adapter.body.0.resnets.1.block1.weight": src_state.pop("body.1.block1.weight"),
+ "adapter.body.0.resnets.1.block1.bias": src_state.pop("body.1.block1.bias"),
+ "adapter.body.0.resnets.1.block2.weight": src_state.pop("body.1.block2.weight"),
+ "adapter.body.0.resnets.1.block2.bias": src_state.pop("body.1.block2.bias"),
+ # 1
+ "adapter.body.1.in_conv.weight": src_state.pop("body.2.in_conv.weight"),
+ "adapter.body.1.in_conv.bias": src_state.pop("body.2.in_conv.bias"),
+ # 1.resnets.0
+ "adapter.body.1.resnets.0.block1.weight": src_state.pop("body.2.block1.weight"),
+ "adapter.body.1.resnets.0.block1.bias": src_state.pop("body.2.block1.bias"),
+ "adapter.body.1.resnets.0.block2.weight": src_state.pop("body.2.block2.weight"),
+ "adapter.body.1.resnets.0.block2.bias": src_state.pop("body.2.block2.bias"),
+ # 1.resnets.1
+ "adapter.body.1.resnets.1.block1.weight": src_state.pop("body.3.block1.weight"),
+ "adapter.body.1.resnets.1.block1.bias": src_state.pop("body.3.block1.bias"),
+ "adapter.body.1.resnets.1.block2.weight": src_state.pop("body.3.block2.weight"),
+ "adapter.body.1.resnets.1.block2.bias": src_state.pop("body.3.block2.bias"),
+ # 2
+ "adapter.body.2.in_conv.weight": src_state.pop("body.4.in_conv.weight"),
+ "adapter.body.2.in_conv.bias": src_state.pop("body.4.in_conv.bias"),
+ # 2.resnets.0
+ "adapter.body.2.resnets.0.block1.weight": src_state.pop("body.4.block1.weight"),
+ "adapter.body.2.resnets.0.block1.bias": src_state.pop("body.4.block1.bias"),
+ "adapter.body.2.resnets.0.block2.weight": src_state.pop("body.4.block2.weight"),
+ "adapter.body.2.resnets.0.block2.bias": src_state.pop("body.4.block2.bias"),
+ # 2.resnets.1
+ "adapter.body.2.resnets.1.block1.weight": src_state.pop("body.5.block1.weight"),
+ "adapter.body.2.resnets.1.block1.bias": src_state.pop("body.5.block1.bias"),
+ "adapter.body.2.resnets.1.block2.weight": src_state.pop("body.5.block2.weight"),
+ "adapter.body.2.resnets.1.block2.bias": src_state.pop("body.5.block2.bias"),
+ # 3.resnets.0
+ "adapter.body.3.resnets.0.block1.weight": src_state.pop("body.6.block1.weight"),
+ "adapter.body.3.resnets.0.block1.bias": src_state.pop("body.6.block1.bias"),
+ "adapter.body.3.resnets.0.block2.weight": src_state.pop("body.6.block2.weight"),
+ "adapter.body.3.resnets.0.block2.bias": src_state.pop("body.6.block2.bias"),
+ # 3.resnets.1
+ "adapter.body.3.resnets.1.block1.weight": src_state.pop("body.7.block1.weight"),
+ "adapter.body.3.resnets.1.block1.bias": src_state.pop("body.7.block1.bias"),
+ "adapter.body.3.resnets.1.block2.weight": src_state.pop("body.7.block2.weight"),
+ "adapter.body.3.resnets.1.block2.bias": src_state.pop("body.7.block2.bias"),
+ }
+
+ assert len(src_state) == 0
+
+ adapter = T2IAdapter(in_channels=in_channels, adapter_type="full_adapter")
+
+ adapter.load_state_dict(res_state)
+
+ return adapter
+
+
+def convert_light_adapter(src_state):
+ original_body_length = max([int(x.split(".")[1]) for x in src_state.keys() if "body." in x]) + 1
+
+ assert original_body_length == 4
+
+ res_state = {
+ # body.0.in_conv
+ "adapter.body.0.in_conv.weight": src_state.pop("body.0.in_conv.weight"),
+ "adapter.body.0.in_conv.bias": src_state.pop("body.0.in_conv.bias"),
+ # body.0.resnets.0
+ "adapter.body.0.resnets.0.block1.weight": src_state.pop("body.0.body.0.block1.weight"),
+ "adapter.body.0.resnets.0.block1.bias": src_state.pop("body.0.body.0.block1.bias"),
+ "adapter.body.0.resnets.0.block2.weight": src_state.pop("body.0.body.0.block2.weight"),
+ "adapter.body.0.resnets.0.block2.bias": src_state.pop("body.0.body.0.block2.bias"),
+ # body.0.resnets.1
+ "adapter.body.0.resnets.1.block1.weight": src_state.pop("body.0.body.1.block1.weight"),
+ "adapter.body.0.resnets.1.block1.bias": src_state.pop("body.0.body.1.block1.bias"),
+ "adapter.body.0.resnets.1.block2.weight": src_state.pop("body.0.body.1.block2.weight"),
+ "adapter.body.0.resnets.1.block2.bias": src_state.pop("body.0.body.1.block2.bias"),
+ # body.0.resnets.2
+ "adapter.body.0.resnets.2.block1.weight": src_state.pop("body.0.body.2.block1.weight"),
+ "adapter.body.0.resnets.2.block1.bias": src_state.pop("body.0.body.2.block1.bias"),
+ "adapter.body.0.resnets.2.block2.weight": src_state.pop("body.0.body.2.block2.weight"),
+ "adapter.body.0.resnets.2.block2.bias": src_state.pop("body.0.body.2.block2.bias"),
+ # body.0.resnets.3
+ "adapter.body.0.resnets.3.block1.weight": src_state.pop("body.0.body.3.block1.weight"),
+ "adapter.body.0.resnets.3.block1.bias": src_state.pop("body.0.body.3.block1.bias"),
+ "adapter.body.0.resnets.3.block2.weight": src_state.pop("body.0.body.3.block2.weight"),
+ "adapter.body.0.resnets.3.block2.bias": src_state.pop("body.0.body.3.block2.bias"),
+ # body.0.out_conv
+ "adapter.body.0.out_conv.weight": src_state.pop("body.0.out_conv.weight"),
+ "adapter.body.0.out_conv.bias": src_state.pop("body.0.out_conv.bias"),
+ # body.1.in_conv
+ "adapter.body.1.in_conv.weight": src_state.pop("body.1.in_conv.weight"),
+ "adapter.body.1.in_conv.bias": src_state.pop("body.1.in_conv.bias"),
+ # body.1.resnets.0
+ "adapter.body.1.resnets.0.block1.weight": src_state.pop("body.1.body.0.block1.weight"),
+ "adapter.body.1.resnets.0.block1.bias": src_state.pop("body.1.body.0.block1.bias"),
+ "adapter.body.1.resnets.0.block2.weight": src_state.pop("body.1.body.0.block2.weight"),
+ "adapter.body.1.resnets.0.block2.bias": src_state.pop("body.1.body.0.block2.bias"),
+ # body.1.resnets.1
+ "adapter.body.1.resnets.1.block1.weight": src_state.pop("body.1.body.1.block1.weight"),
+ "adapter.body.1.resnets.1.block1.bias": src_state.pop("body.1.body.1.block1.bias"),
+ "adapter.body.1.resnets.1.block2.weight": src_state.pop("body.1.body.1.block2.weight"),
+ "adapter.body.1.resnets.1.block2.bias": src_state.pop("body.1.body.1.block2.bias"),
+ # body.1.body.2
+ "adapter.body.1.resnets.2.block1.weight": src_state.pop("body.1.body.2.block1.weight"),
+ "adapter.body.1.resnets.2.block1.bias": src_state.pop("body.1.body.2.block1.bias"),
+ "adapter.body.1.resnets.2.block2.weight": src_state.pop("body.1.body.2.block2.weight"),
+ "adapter.body.1.resnets.2.block2.bias": src_state.pop("body.1.body.2.block2.bias"),
+ # body.1.body.3
+ "adapter.body.1.resnets.3.block1.weight": src_state.pop("body.1.body.3.block1.weight"),
+ "adapter.body.1.resnets.3.block1.bias": src_state.pop("body.1.body.3.block1.bias"),
+ "adapter.body.1.resnets.3.block2.weight": src_state.pop("body.1.body.3.block2.weight"),
+ "adapter.body.1.resnets.3.block2.bias": src_state.pop("body.1.body.3.block2.bias"),
+ # body.1.out_conv
+ "adapter.body.1.out_conv.weight": src_state.pop("body.1.out_conv.weight"),
+ "adapter.body.1.out_conv.bias": src_state.pop("body.1.out_conv.bias"),
+ # body.2.in_conv
+ "adapter.body.2.in_conv.weight": src_state.pop("body.2.in_conv.weight"),
+ "adapter.body.2.in_conv.bias": src_state.pop("body.2.in_conv.bias"),
+ # body.2.body.0
+ "adapter.body.2.resnets.0.block1.weight": src_state.pop("body.2.body.0.block1.weight"),
+ "adapter.body.2.resnets.0.block1.bias": src_state.pop("body.2.body.0.block1.bias"),
+ "adapter.body.2.resnets.0.block2.weight": src_state.pop("body.2.body.0.block2.weight"),
+ "adapter.body.2.resnets.0.block2.bias": src_state.pop("body.2.body.0.block2.bias"),
+ # body.2.body.1
+ "adapter.body.2.resnets.1.block1.weight": src_state.pop("body.2.body.1.block1.weight"),
+ "adapter.body.2.resnets.1.block1.bias": src_state.pop("body.2.body.1.block1.bias"),
+ "adapter.body.2.resnets.1.block2.weight": src_state.pop("body.2.body.1.block2.weight"),
+ "adapter.body.2.resnets.1.block2.bias": src_state.pop("body.2.body.1.block2.bias"),
+ # body.2.body.2
+ "adapter.body.2.resnets.2.block1.weight": src_state.pop("body.2.body.2.block1.weight"),
+ "adapter.body.2.resnets.2.block1.bias": src_state.pop("body.2.body.2.block1.bias"),
+ "adapter.body.2.resnets.2.block2.weight": src_state.pop("body.2.body.2.block2.weight"),
+ "adapter.body.2.resnets.2.block2.bias": src_state.pop("body.2.body.2.block2.bias"),
+ # body.2.body.3
+ "adapter.body.2.resnets.3.block1.weight": src_state.pop("body.2.body.3.block1.weight"),
+ "adapter.body.2.resnets.3.block1.bias": src_state.pop("body.2.body.3.block1.bias"),
+ "adapter.body.2.resnets.3.block2.weight": src_state.pop("body.2.body.3.block2.weight"),
+ "adapter.body.2.resnets.3.block2.bias": src_state.pop("body.2.body.3.block2.bias"),
+ # body.2.out_conv
+ "adapter.body.2.out_conv.weight": src_state.pop("body.2.out_conv.weight"),
+ "adapter.body.2.out_conv.bias": src_state.pop("body.2.out_conv.bias"),
+ # body.3.in_conv
+ "adapter.body.3.in_conv.weight": src_state.pop("body.3.in_conv.weight"),
+ "adapter.body.3.in_conv.bias": src_state.pop("body.3.in_conv.bias"),
+ # body.3.body.0
+ "adapter.body.3.resnets.0.block1.weight": src_state.pop("body.3.body.0.block1.weight"),
+ "adapter.body.3.resnets.0.block1.bias": src_state.pop("body.3.body.0.block1.bias"),
+ "adapter.body.3.resnets.0.block2.weight": src_state.pop("body.3.body.0.block2.weight"),
+ "adapter.body.3.resnets.0.block2.bias": src_state.pop("body.3.body.0.block2.bias"),
+ # body.3.body.1
+ "adapter.body.3.resnets.1.block1.weight": src_state.pop("body.3.body.1.block1.weight"),
+ "adapter.body.3.resnets.1.block1.bias": src_state.pop("body.3.body.1.block1.bias"),
+ "adapter.body.3.resnets.1.block2.weight": src_state.pop("body.3.body.1.block2.weight"),
+ "adapter.body.3.resnets.1.block2.bias": src_state.pop("body.3.body.1.block2.bias"),
+ # body.3.body.2
+ "adapter.body.3.resnets.2.block1.weight": src_state.pop("body.3.body.2.block1.weight"),
+ "adapter.body.3.resnets.2.block1.bias": src_state.pop("body.3.body.2.block1.bias"),
+ "adapter.body.3.resnets.2.block2.weight": src_state.pop("body.3.body.2.block2.weight"),
+ "adapter.body.3.resnets.2.block2.bias": src_state.pop("body.3.body.2.block2.bias"),
+ # body.3.body.3
+ "adapter.body.3.resnets.3.block1.weight": src_state.pop("body.3.body.3.block1.weight"),
+ "adapter.body.3.resnets.3.block1.bias": src_state.pop("body.3.body.3.block1.bias"),
+ "adapter.body.3.resnets.3.block2.weight": src_state.pop("body.3.body.3.block2.weight"),
+ "adapter.body.3.resnets.3.block2.bias": src_state.pop("body.3.body.3.block2.bias"),
+ # body.3.out_conv
+ "adapter.body.3.out_conv.weight": src_state.pop("body.3.out_conv.weight"),
+ "adapter.body.3.out_conv.bias": src_state.pop("body.3.out_conv.bias"),
+ }
+
+ assert len(src_state) == 0
+
+ adapter = T2IAdapter(in_channels=3, channels=[320, 640, 1280], num_res_blocks=4, adapter_type="light_adapter")
+
+ adapter.load_state_dict(res_state)
+
+ return adapter
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
+ )
+ parser.add_argument(
+ "--output_path", default=None, type=str, required=True, help="Path to the store the result checkpoint."
+ )
+ parser.add_argument(
+ "--is_adapter_light",
+ action="store_true",
+ help="Is checkpoint come from Adapter-Light architecture. ex: color-adapter",
+ )
+ parser.add_argument("--in_channels", required=False, type=int, help="Input channels for non-light adapter")
+
+ args = parser.parse_args()
+ src_state = torch.load(args.checkpoint_path)
+
+ if args.is_adapter_light:
+ adapter = convert_light_adapter(src_state)
+ else:
+ if args.in_channels is None:
+ raise ValueError("set `--in_channels=`")
+ adapter = convert_adapter(src_state, args.in_channels)
+
+ adapter.save_pretrained(args.output_path)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 39178edc00d1..0c0869fb52bd 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -39,7 +39,9 @@
AutoencoderKL,
ControlNetModel,
ModelMixin,
+ MultiAdapter,
PriorTransformer,
+ T2IAdapter,
T5FilmDecoder,
Transformer2DModel,
UNet1DModel,
@@ -151,6 +153,7 @@
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
+ StableDiffusionAdapterPipeline,
StableDiffusionAttendAndExcitePipeline,
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 23839c84af45..6e330a44691a 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -16,6 +16,7 @@
if is_torch_available():
+ from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_kl import AutoencoderKL
from .controlnet import ControlNetModel
from .dual_transformer_2d import DualTransformer2DModel
diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py
new file mode 100644
index 000000000000..a65a3873b130
--- /dev/null
+++ b/src/diffusers/models/adapter.py
@@ -0,0 +1,291 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .modeling_utils import ModelMixin
+from .resnet import Downsample2D
+
+
+class MultiAdapter(ModelMixin):
+ r"""
+ MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
+ user-assigned weighting.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ adapters (`List[T2IAdapter]`, *optional*, defaults to None):
+ A list of `T2IAdapter` model instances.
+ """
+
+ def __init__(self, adapters: List["T2IAdapter"]):
+ super(MultiAdapter, self).__init__()
+
+ self.num_adapter = len(adapters)
+ self.adapters = nn.ModuleList(adapters)
+
+ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
+ r"""
+ Args:
+ xs (`torch.Tensor`):
+ (batch, channel, height, width) input images for multiple adapter models concated along dimension 1,
+ `channel` should equal to `num_adapter` * "number of channel of image".
+ adapter_weights (`List[float]`, *optional*, defaults to None):
+ List of floats representing the weight which will be multiply to each adapter's output before adding
+ them together.
+ """
+ if adapter_weights is None:
+ adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
+ else:
+ adapter_weights = torch.tensor(adapter_weights)
+
+ if xs.shape[1] % self.num_adapter != 0:
+ raise ValueError(
+ f"Expecting multi-adapter's input have number of channel that cab be evenly divisible "
+ f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0"
+ )
+ x_list = torch.chunk(xs, self.num_adapter, dim=1)
+ accume_state = None
+ for x, w, adapter in zip(x_list, adapter_weights, self.adapters):
+ features = adapter(x)
+ if accume_state is None:
+ accume_state = features
+ else:
+ for i in range(len(features)):
+ accume_state[i] += w * features[i]
+ return accume_state
+
+
+class T2IAdapter(ModelMixin, ConfigMixin):
+ r"""
+ A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
+ generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
+ architecture follows the original implementation of
+ [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
+ and
+ [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ in_channels (`int`, *optional*, defaults to 3):
+ Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale
+ image as *control image*.
+ channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will
+ also determine the number of downsample blocks in the Adapter.
+ num_res_blocks (`int`, *optional*, defaults to 2):
+ Number of ResNet blocks in each downsample block
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ channels: List[int] = [320, 640, 1280, 1280],
+ num_res_blocks: int = 2,
+ downscale_factor: int = 8,
+ adapter_type: str = "full_adapter",
+ ):
+ super().__init__()
+
+ if adapter_type == "full_adapter":
+ self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
+ elif adapter_type == "light_adapter":
+ self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
+ else:
+ raise ValueError(f"unknown adapter_type: {type}. Choose either 'full_adapter' or 'simple_adapter'")
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ return self.adapter(x)
+
+ @property
+ def total_downscale_factor(self):
+ return self.adapter.total_downscale_factor
+
+
+# full adapter
+
+
+class FullAdapter(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 3,
+ channels: List[int] = [320, 640, 1280, 1280],
+ num_res_blocks: int = 2,
+ downscale_factor: int = 8,
+ ):
+ super().__init__()
+
+ in_channels = in_channels * downscale_factor**2
+
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
+ self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
+
+ self.body = nn.ModuleList(
+ [
+ AdapterBlock(channels[0], channels[0], num_res_blocks),
+ *[
+ AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)
+ for i in range(1, len(channels))
+ ],
+ ]
+ )
+
+ self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ x = self.unshuffle(x)
+ x = self.conv_in(x)
+
+ features = []
+
+ for block in self.body:
+ x = block(x)
+ features.append(x)
+
+ return features
+
+
+class AdapterBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
+ super().__init__()
+
+ self.downsample = None
+ if down:
+ self.downsample = Downsample2D(in_channels)
+
+ self.in_conv = None
+ if in_channels != out_channels:
+ self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+ self.resnets = nn.Sequential(
+ *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
+ )
+
+ def forward(self, x):
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ if self.in_conv is not None:
+ x = self.in_conv(x)
+
+ x = self.resnets(x)
+
+ return x
+
+
+class AdapterResnetBlock(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+ self.act = nn.ReLU()
+ self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
+
+ def forward(self, x):
+ h = x
+ h = self.block1(h)
+ h = self.act(h)
+ h = self.block2(h)
+
+ return h + x
+
+
+# light adapter
+
+
+class LightAdapter(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 3,
+ channels: List[int] = [320, 640, 1280],
+ num_res_blocks: int = 4,
+ downscale_factor: int = 8,
+ ):
+ super().__init__()
+
+ in_channels = in_channels * downscale_factor**2
+
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
+
+ self.body = nn.ModuleList(
+ [
+ LightAdapterBlock(in_channels, channels[0], num_res_blocks),
+ *[
+ LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True)
+ for i in range(len(channels) - 1)
+ ],
+ LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True),
+ ]
+ )
+
+ self.total_downscale_factor = downscale_factor * (2 ** len(channels))
+
+ def forward(self, x):
+ x = self.unshuffle(x)
+
+ features = []
+
+ for block in self.body:
+ x = block(x)
+ features.append(x)
+
+ return features
+
+
+class LightAdapterBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
+ super().__init__()
+ mid_channels = out_channels // 4
+
+ self.downsample = None
+ if down:
+ self.downsample = Downsample2D(in_channels)
+
+ self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
+ self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
+ self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
+
+ def forward(self, x):
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ x = self.in_conv(x)
+ x = self.resnets(x)
+ x = self.out_conv(x)
+
+ return x
+
+
+class LightAdapterResnetBlock(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+ self.act = nn.ReLU()
+ self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+
+ def forward(self, x):
+ h = x
+ h = self.block1(h)
+ h = self.act(h)
+ h = self.block2(h)
+
+ return h + x
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index cb3452f4459c..469e501b814b 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -955,10 +955,13 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ additional_residuals=None,
):
output_states = ()
- for resnet, attn in zip(self.resnets, self.attentions):
+ blocks = list(zip(self.resnets, self.attentions))
+
+ for i, (resnet, attn) in enumerate(blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
@@ -999,6 +1002,10 @@ def custom_forward(*inputs):
return_dict=False,
)[0]
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
+ if i == len(blocks) - 1 and additional_residuals is not None:
+ hidden_states = hidden_states + additional_residuals
+
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index dee71bead0f9..d7756ab5edb3 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -899,9 +899,18 @@ def forward(
sample = self.conv_in(sample)
# 3. down
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
@@ -909,13 +918,17 @@ def forward(
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ sample += down_block_additional_residuals.pop(0)
+
down_block_res_samples += res_samples
- if down_block_additional_residuals is not None:
+ if is_controlnet:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
@@ -937,7 +950,7 @@ def forward(
encoder_attention_mask=encoder_attention_mask,
)
- if mid_block_additional_residual is not None:
+ if is_controlnet:
sample = sample + mid_block_additional_residual
# 5. up
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 937ac1b5e3d7..aa09e7e81130 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -101,6 +101,7 @@
StableUnCLIPPipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
+ from .t2i_adapter import StableDiffusionAdapterPipeline
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
diff --git a/src/diffusers/pipelines/t2i_adapter/__init__.py b/src/diffusers/pipelines/t2i_adapter/__init__.py
new file mode 100644
index 000000000000..c4de661dbefa
--- /dev/null
+++ b/src/diffusers/pipelines/t2i_adapter/__init__.py
@@ -0,0 +1,14 @@
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+else:
+ from .pipeline_stable_diffusion_adapter import StableDiffusionAdapterPipeline
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
new file mode 100644
index 000000000000..c84c99e6d19a
--- /dev/null
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -0,0 +1,818 @@
+# Copyright 2023 TencentARC and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import warnings
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ PIL_INTERPOLATION,
+ BaseOutput,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
+
+@dataclass
+class StableDiffusionAdapterPipelineOutput(BaseOutput):
+ """
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, or `None` if safety checking could not be performed.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: Optional[List[bool]]
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from PIL import Image
+ >>> from diffusers.utils import load_image
+ >>> import torch
+ >>> from diffusers import StableDiffusionAdapterPipeline, T2IAdapter
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png"
+ ... )
+
+ >>> color_palette = image.resize((8, 8))
+ >>> color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST)
+
+ >>> adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionAdapterPipeline.from_pretrained(
+ ... "CompVis/stable-diffusion-v1-4",
+ ... adapter=adapter,
+ ... torch_dtype=torch.float16,
+ ... )
+
+ >>> pipe.to("cuda")
+
+ >>> out_image = pipe(
+ ... "At night, glowing cubes in front of the beach",
+ ... image=color_palette,
+ ... ).images[0]
+ ```
+"""
+
+
+def _preprocess_adapter_image(image, height, width):
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image]
+ image = [
+ i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image
+ ] # expand [h, w] or [h, w, c] to [b, h, w, c]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ if image[0].ndim == 3:
+ image = torch.stack(image, dim=0)
+ elif image[0].ndim == 4:
+ image = torch.cat(image, dim=0)
+ else:
+ raise ValueError(
+ f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}"
+ )
+ return image
+
+
+class StableDiffusionAdapterPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
+ https://arxiv.org/abs/2302.08453
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
+ Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a
+ list, the outputs from each Adapter are added together to create one combined additional conditioning.
+ adapter_weights (`List[float]`, *optional*, defaults to None):
+ List of floats representing the weight which will be multiply to each adapter's output before adding them
+ together.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ adapter_weights: Optional[List[float]] = None,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if isinstance(adapter, (list, tuple)):
+ adapter = MultiAdapter(adapter, adapter_weights=adapter_weights)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ adapter=adapter,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
+ `enable_model_cpu_offload`, but performance is lower.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.adapter]:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.adapter, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: procecss multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def _default_height_width(self, height, width, image):
+ # NOTE: It is possible that a list of images have different
+ # dimensions for each image, so just checking the first image
+ # is not _exactly_ correct, but it is simple.
+ while isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[-2]
+
+ # round down to nearest multiple of `self.adapter.total_downscale_factor`
+ height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[-1]
+
+ # round down to nearest multiple of `self.adapter.total_downscale_factor`
+ width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
+
+ return height, width
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ adapter_conditioning_scale: Union[float, List[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`):
+ The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the
+ type is specified as `Torch.FloatTensor`, it is passed to Adapter as is. PIL.Image.Image` can also be
+ accepted as an image. The control image is automatically resized to fit the output image.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the
+ residual in the original unet. If multiple adapters are specified in init, you can set the
+ corresponding scale as a list.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple. When returning a tuple, the first element is a list with the generated images, and the second
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height, width = self._default_height_width(height, width, image)
+ device = self._execution_device
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ is_multi_adapter = isinstance(self.adapter, MultiAdapter)
+ if is_multi_adapter:
+ adapter_input = [_preprocess_adapter_image(img, height, width).to(device) for img in image]
+ n, c, h, w = adapter_input[0].shape
+ adapter_input = torch.stack([x.reshape([n * c, h, w]) for x in adapter_input])
+ else:
+ adapter_input = _preprocess_adapter_image(image, height, width).to(device)
+ adapter_input = adapter_input.to(self.adapter.dtype)
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ adapter_state = self.adapter(adapter_input)
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v * adapter_conditioning_scale
+ if num_images_per_prompt > 1:
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
+ if do_classifier_free_guidance:
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = torch.cat([v] * 2, dim=0)
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=[state.clone() for state in adapter_state],
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ elif output_type == "pil":
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+ else:
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionAdapterPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 0a2fad6aee1a..82628104eba2 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -1012,9 +1012,18 @@ def forward(
sample = self.conv_in(sample)
# 3. down
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlockFlat
+ additional_residuals = {}
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
@@ -1022,13 +1031,17 @@ def forward(
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ sample += down_block_additional_residuals.pop(0)
+
down_block_res_samples += res_samples
- if down_block_additional_residuals is not None:
+ if is_controlnet:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
@@ -1050,7 +1063,7 @@ def forward(
encoder_attention_mask=encoder_attention_mask,
)
- if mid_block_additional_residual is not None:
+ if is_controlnet:
sample = sample + mid_block_additional_residual
# 5. up
@@ -1390,10 +1403,13 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ additional_residuals=None,
):
output_states = ()
- for resnet, attn in zip(self.resnets, self.attentions):
+ blocks = list(zip(self.resnets, self.attentions))
+
+ for i, (resnet, attn) in enumerate(blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
@@ -1434,6 +1450,10 @@ def custom_forward(*inputs):
return_dict=False,
)[0]
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
+ if i == len(blocks) - 1 and additional_residuals is not None:
+ hidden_states = hidden_states + additional_residuals
+
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 20dbf84681d3..b955ec5320de 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -47,6 +47,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class MultiAdapter(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class PriorTransformer(metaclass=DummyObject):
_backends = ["torch"]
@@ -62,6 +77,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class T2IAdapter(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class T5FilmDecoder(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 164206d776fa..016760337c69 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -407,6 +407,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class StableDiffusionAdapterPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py
new file mode 100644
index 000000000000..0c1dd1cfe87b
--- /dev/null
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py
@@ -0,0 +1,316 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import random
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ PNDMScheduler,
+ StableDiffusionAdapterPipeline,
+ T2IAdapter,
+ UNet2DConditionModel,
+)
+from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
+
+from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class AdapterTests:
+ pipeline_class = StableDiffusionAdapterPipeline
+ params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
+ batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+
+ def get_dummy_components(self, adapter_type):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ )
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ torch.manual_seed(0)
+ adapter = T2IAdapter(
+ in_channels=3,
+ channels=[32, 64],
+ num_res_blocks=2,
+ downscale_factor=2,
+ adapter_type=adapter_type,
+ )
+
+ components = {
+ "adapter": adapter,
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "safety_checker": None,
+ "feature_extractor": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "output_type": "numpy",
+ }
+ return inputs
+
+ def test_attention_slicing_forward_pass(self):
+ return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_attention_forwardGenerator_pass(self):
+ self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+
+
+class StableDiffusionFullAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
+ def get_dummy_components(self):
+ return super().get_dummy_components("full_adapter")
+
+ def test_stable_diffusion_adapter_default_case(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = StableDiffusionAdapterPipeline(**components)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+ expected_slice = np.array([0.4858, 0.5500, 0.4278, 0.4669, 0.6184, 0.4322, 0.5010, 0.5033, 0.4746])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
+
+
+class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
+ def get_dummy_components(self):
+ return super().get_dummy_components("light_adapter")
+
+ def test_stable_diffusion_adapter_default_case(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ sd_pipe = StableDiffusionAdapterPipeline(**components)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+ expected_slice = np.array([0.4965, 0.5548, 0.4330, 0.4771, 0.6226, 0.4382, 0.5037, 0.5071, 0.4782])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
+
+
+@slow
+@require_torch_gpu
+class StableDiffusionAdapterPipelineSlowTests(unittest.TestCase):
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_stable_diffusion_adapter(self):
+ test_cases = [
+ (
+ "TencentARC/t2iadapter_color_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "snail",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/color.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_color_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_depth_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "desk",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/desk_depth.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_depth_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_depth_sd15v2",
+ "runwayml/stable-diffusion-v1-5",
+ "desk",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/desk_depth.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_depth_sd15v2.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_keypose_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "person",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/person_keypose.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_keypose_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_openpose_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "person",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/iron_man_pose.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_openpose_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_seg_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "motorcycle",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motor.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_seg_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_zoedepth_sd15v1",
+ "runwayml/stable-diffusion-v1-5",
+ "motorcycle",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motorcycle.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_zoedepth_sd15v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_canny_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "toy",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png",
+ 1,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_canny_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_canny_sd15v2",
+ "runwayml/stable-diffusion-v1-5",
+ "toy",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png",
+ 1,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_canny_sd15v2.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_sketch_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "cat",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/edge.png",
+ 1,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_sketch_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_sketch_sd15v2",
+ "runwayml/stable-diffusion-v1-5",
+ "cat",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/edge.png",
+ 1,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_sketch_sd15v2.npy",
+ ),
+ ]
+
+ for adapter_model, sd_model, prompt, image_url, input_channels, out_url in test_cases:
+ image = load_image(image_url)
+ expected_out = load_numpy(out_url)
+
+ if input_channels == 1:
+ image = image.convert("L")
+
+ adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
+
+ pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+
+ out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
+
+ self.assertTrue(np.allclose(out, expected_out))
+
+ def test_stable_diffusion_adapter_pipeline_with_sequential_cpu_offloading(self):
+ torch.cuda.empty_cache()
+ torch.cuda.reset_max_memory_allocated()
+ torch.cuda.reset_peak_memory_stats()
+
+ adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_seg_sd14v1")
+ pipe = StableDiffusionAdapterPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", adapter=adapter, safety_checker=None
+ )
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing(1)
+ pipe.enable_sequential_cpu_offload()
+
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motor.png"
+ )
+
+ pipe(prompt="foo", image=image, num_inference_steps=2)
+
+ mem_bytes = torch.cuda.max_memory_allocated()
+ assert mem_bytes < 5 * 10**9