From 1d0b1a50c311618c13303d8cc299b97db98ae6f7 Mon Sep 17 00:00:00 2001 From: HimariO Date: Sun, 19 Feb 2023 04:07:50 +0800 Subject: [PATCH 01/10] Quick implementation of t2i-adapter Load adapter module with from_pretrained Prototyping generalized adapter framework Writeup doc string for sideload framework(WIP) + some minor update on implementation Update adapter models Remove old adapter optional args in UNet Add StableDiffusionAdapterPipeline unit test Handle cpu offload in StableDiffusionAdapterPipeline Auto correct coding style Update model repo name to "RzZ/sd-v1-4-adapter-pipeline" Refactor MultiAdapter to better compatible with config system Export MultiAdapter Create pipeline document template from controlnet Create dummy objects Supproting new AdapterLight model Fix StableDiffusionAdapterPipeline common pipeline test [WIP] Update adapter pipeline document Handle num_inference_steps in StableDiffusionAdapterPipeline Update definition of Adapter "channels_in" Update documents Apply code style Fix doc typo and merge error Update doc string and example Quality of life improvement Remove redundant code and file from prototyping Remove unused pageage Remove comments Fix title Fix typo Add conditioning scale arg Bring back old implmentation Offload sideload Add supply info on document Update src/diffusers/models/adapter.py Co-authored-by: Will Berman Update MultiAdapter constructor Swap out custom checkpoint and update pipeline constructor Update docment Apply suggestions from code review Co-authored-by: Will Berman Correcting style Following single-file policy Update auto size in image preprocess func Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py Co-authored-by: Will Berman fix copies Update adapter pipeline behavior Add adapter_conditioning_scale doc string Add the missing doc string Apply suggestions from code review Co-authored-by: Patrick von Platen Fix few bugs from suggestion Handle L-mode PIL image as control image Rename to differentiate adapter resblock Update src/diffusers/models/adapter.py Co-authored-by: Sayak Paul Fix typo Update adapter parameter name Update test case and code style Fix copies Fix typo Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py Co-authored-by: Will Berman Update Adapter class name Add checkpoint converting script Fix style Fix-copies Remove dev script Apply suggestions from code review Co-authored-by: Patrick von Platen Updates for parameter rename Fix convert_adapter remove main fix diff more refactoring more more small fixes refactor tests more slow tests more tests Update docs/source/en/api/pipelines/overview.mdx Co-authored-by: Sayak Paul add community contributor to docs Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx Co-authored-by: Sayak Paul Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx Co-authored-by: Sayak Paul Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx Co-authored-by: Sayak Paul Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx Co-authored-by: Sayak Paul Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx Co-authored-by: Sayak Paul fix remove from_adapters license paper link docs more url fixes more docs fix fixes fix fix --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/pipelines/overview.mdx | 1 + .../pipelines/stable_diffusion/adapter.mdx | 187 ++++ docs/source/en/index.mdx | 1 + .../controlling_generation.mdx | 12 +- scripts/convert_original_t2i_adapter.py | 250 ++++++ src/diffusers/__init__.py | 3 + src/diffusers/models/__init__.py | 1 + src/diffusers/models/adapter.py | 291 +++++++ src/diffusers/models/unet_2d_blocks.py | 5 + src/diffusers/models/unet_2d_condition.py | 18 +- src/diffusers/pipelines/__init__.py | 1 + .../pipelines/stable_diffusion/__init__.py | 1 + .../pipeline_stable_diffusion_adapter.py | 799 ++++++++++++++++++ .../versatile_diffusion/modeling_text_unet.py | 23 +- src/diffusers/utils/dummy_pt_objects.py | 30 + .../dummy_torch_and_transformers_objects.py | 15 + .../test_stable_diffusion_adapter.py | 316 +++++++ 18 files changed, 1949 insertions(+), 7 deletions(-) create mode 100644 docs/source/en/api/pipelines/stable_diffusion/adapter.mdx create mode 100644 scripts/convert_original_t2i_adapter.py create mode 100644 src/diffusers/models/adapter.py create mode 100644 src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py create mode 100644 tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f695dd3c1df9..f477f288bbe0 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -255,6 +255,8 @@ title: Super-Resolution - local: api/pipelines/stable_diffusion/ldm3d_diffusion title: LDM3D Text-to-(RGB, Depth) + - local: api/pipelines/stable_diffusion/adapter + title: Stable Diffusion T2I-adapter title: Stable Diffusion - local: api/pipelines/stable_unclip title: Stable unCLIP diff --git a/docs/source/en/api/pipelines/overview.mdx b/docs/source/en/api/pipelines/overview.mdx index 693c32565c46..1d61ae6a1314 100644 --- a/docs/source/en/api/pipelines/overview.mdx +++ b/docs/source/en/api/pipelines/overview.mdx @@ -66,6 +66,7 @@ available a colab notebook to directly try them out. | [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | | [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | | [semantic_stable_diffusion](./semantic_stable_diffusion) | [**SEGA: Instructing Diffusion using Semantic Dimensions**](https://arxiv.org/abs/2301.12247) | Text-to-Image Generation | +| [stable_diffusion_adapter](./stable_diffusion/adapter) | [**T2I-Adapter**](https://arxiv.org/abs/2302.08453) | Image-to-Image Text-Guided Generation with Adapters | - | [stable_diffusion_text2img](./stable_diffusion/text2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) | [stable_diffusion_img2img](./stable_diffusion/img2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [stable_diffusion_inpaint](./stable_diffusion/inpaint) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) diff --git a/docs/source/en/api/pipelines/stable_diffusion/adapter.mdx b/docs/source/en/api/pipelines/stable_diffusion/adapter.mdx new file mode 100644 index 000000000000..6e07e55250e6 --- /dev/null +++ b/docs/source/en/api/pipelines/stable_diffusion/adapter.mdx @@ -0,0 +1,187 @@ + + +# Text-to-Image Generation with Adapter Conditioning + +## Overview + +[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.08453) by Chong Mou, Xintao Wang, Liangbin Xie, Jian Zhang, Zhongang Qi, Ying Shan, Xiaohu Qie. + +Using the pretrained models we can provide control images (for example, a depth map) to control Stable Diffusion text-to-image generation so that it follows the structure of the depth image and fills in the details. + +The abstract of the paper is the following: + +*The incredible generative ability of large-scale text-to-image (T2I) models has demonstrated strong power of learning complex structures and meaningful semantics. However, relying solely on text prompts cannot fully take advantage of the knowledge learned by the model, especially when flexible and accurate structure control is needed. In this paper, we aim to ``dig out" the capabilities that T2I models have implicitly learned, and then explicitly use them to control the generation more granularly. Specifically, we propose to learn simple and small T2I-Adapters to align internal knowledge in T2I models with external control signals, while freezing the original large T2I models. In this way, we can train various adapters according to different conditions, and achieve rich control and editing effects. Further, the proposed T2I-Adapters have attractive properties of practical value, such as composability and generalization ability. Extensive experiments demonstrate that our T2I-Adapter has promising generation quality and a wide range of applications.* + +This model was contributed by the community contributor [HimariO](https://github.com/HimariO) ❤️ . + +## Available Pipelines: + +| Pipeline | Tasks | Demo +|---|---|:---:| +| [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | - + +## Usage example + +In the following we give a simple example of how to use a *T2IAdapter* checkpoint with Diffusers for inference. +All adapters use the same pipeline. + + 1. Images are first converted into the appropriate *control image* format. + 2. The *control image* and *prompt* are passed to the [`StableDiffusionAdapterPipeline`]. + +Let's have a look at a simple example using the [Color Adapter](https://huggingface.co/diffusers/t2iadapter_color_sd14v1). + +```python +from diffusers.utils import load_image + +image = load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png") +``` + +![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png) + + +Then we can create our color palette by simply resizing it to 8 by 8 pixels and then scaling it back to original size. + +```python +from PIL import Image + +color_palette = image.resize((8, 8)) +color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST) +``` + +Let's take a look at the processed image. + +![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_palette.png) + + +Next, create the adapter pipeline + +```py +import torch +from diffusers import StableDiffusionAdapterPipeline, T2IAdapter + +adapter = T2IAdapter.from_pretrained("diffusers/t2iadapter_color_sd14v1") +pipe = StableDiffusionAdapterPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + adapter=adapter, + torch_dtype=torch.float16, +) +pipe.to("cuda") +``` + +Finally, pass the prompt and control image to the pipeline + +```py +# fix the random seed, so you will get the same result as the example +generator = torch.manual_seed(7) + +out_image = pipe( + "At night, glowing cubes in front of the beach", + image=color_palette, + generator=generator, +).images[0] +``` + +![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_output.png) + + +## Available checkpoints + +Non-diffusers checkpoints can be found under [TencentARC/T2I-Adapter](https://huggingface.co/TencentARC/T2I-Adapter/tree/main/models). + +### T2I-Adapter with Stable Diffusion 1.4 + +| Model Name | Control Image Overview| Control Image Example | Generated Image Example | +|---|---|---|---| +|[diffusers/t2iadapter_color_sd14v1](https://huggingface.co/diffusers/t2iadapter_color_sd14v1)
*Trained with spatial color palette* | A image with 8x8 color palette.||| +|[diffusers/t2iadapter_canny_sd14v1](https://huggingface.co/diffusers/t2iadapter_canny_sd14v1)
*Trained with canny edge detection* | A monochrome image with white edges on a black background.||| +|[diffusers/t2iadapter_sketch_sd14v1](https://huggingface.co/diffusers/t2iadapter_sketch_sd14v1)
*Trained with [PidiNet](https://github.com/zhuoinoulu/pidinet) edge detection* | A hand-drawn monochrome image with white outlines on a black background.||| +|[diffusers/t2iadapter_depth_sd14v1](https://huggingface.co/diffusers/t2iadapter_depth_sd14v1)
*Trained with Midas depth estimation* | A grayscale image with black representing deep areas and white representing shallow areas.||| +|[diffusers/t2iadapter_openpose_sd14v1](https://huggingface.co/diffusers/t2iadapter_openpose_sd14v1)
*Trained with OpenPose bone image* | A [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image.||| +|[diffusers/t2iadapter_keypose_sd14v1](https://huggingface.co/diffusers/t2iadapter_keypose_sd14v1)
*Trained with mmpose skeleton image* | A [mmpose skeleton](https://github.com/open-mmlab/mmpose) image.||| +|[diffusers/t2iadapter_seg_sd14v1](https://huggingface.co/diffusers/t2iadapter_seg_sd14v1)
*Trained with semantic segmentation* | An [custom](https://github.com/TencentARC/T2I-Adapter/discussions/25) segmentation protocol image.|| | +|[diffusers/t2iadapter_canny_sd15v2](https://huggingface.co/diffusers/t2iadapter_canny_sd15v2)|| +|[diffusers/t2iadapter_depth_sd15v2](https://huggingface.co/diffusers/t2iadapter_depth_sd15v2)|| +|[diffusers/t2iadapter_sketch_sd15v2](https://huggingface.co/diffusers/t2iadapter_sketch_sd15v2)|| +|[diffusers/t2iadapter_zoedepth_sd15v1](https://huggingface.co/diffusers/t2iadapter_zoedepth_sd15v1)|| + +## Combining multiple adapters + +[`MultiAdapter`] can be used for applying multiple conditionings at once. + +Here we use the keypose adapter for the character posture and the depth adapter for creating the scene. + +```py +import torch +from PIL import Image +from diffusers.utils import load_image + +cond_keypose = load_image( + "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png" +) +cond_depth = load_image( + "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png" +) +cond = [[cond_keypose, cond_depth]] + +prompt = ["A man walking in an office room with a nice view"] +``` + +The two control images look as such: + +![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png) +![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png) + + +`MultiAdapter` combines keypose and depth adapters. + +`adapter_conditioning_scale` balances the relative influence of the different adapters. + +```py +from diffusers import StableDiffusionAdapterPipeline, MultiAdapter + +adapters = MultiAdapter( + [ + T2IAdapter.from_pretrained("diffusers/t2iadapter_keypose_sd14v1"), + T2IAdapter.from_pretrained("diffusers/t2iadapter_depth_sd14v1"), + ] +) +adapters = adapters.to(torch.float16) + +pipe = StableDiffusionAdapterPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + torch_dtype=torch.float16, + adapter=adapters, +) + +images = pipe(prompt, cond, adapter_conditioning_scale=[0.8, 0.8]) +``` + +![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_depth_sample_output.png) + + +## T2I Adapter vs ControlNet + +T2I-Adapter is similar to [ControlNet](https://huggingface.co/docs/diffusers/main/en/api/pipelines/controlnet). +T2i-Adapter uses a smaller auxiliary network which is only run once for the entire diffusion process. +However, T2I-Adapter performs slightly worse than ControlNet. + +## StableDiffusionAdapterPipeline +[[autodoc]] StableDiffusionAdapterPipeline + - all + - __call__ + - enable_attention_slicing + - disable_attention_slicing + - enable_vae_slicing + - disable_vae_slicing + - enable_xformers_memory_efficient_attention + - disable_xformers_memory_efficient_attention diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 1bb52e628a04..f2012abc6970 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -69,6 +69,7 @@ The library has three main components: | [score_sde_ve](./api/pipelines/score_sde_ve) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | | [score_sde_vp](./api/pipelines/score_sde_vp) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | | [semantic_stable_diffusion](./api/pipelines/semantic_stable_diffusion) | [Semantic Guidance](https://arxiv.org/abs/2301.12247) | Text-Guided Generation | +| [stable_diffusion_adapter](./api/pipelines/stable_diffusion/adapter) | [**T2I-Adapter**](https://arxiv.org/abs/2302.08453) | Image-to-Image Text-Guided Generation | - | [stable_diffusion_text2img](./api/pipelines/stable_diffusion/text2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | | [stable_diffusion_img2img](./api/pipelines/stable_diffusion/img2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | | [stable_diffusion_inpaint](./api/pipelines/stable_diffusion/inpaint) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | diff --git a/docs/source/en/using-diffusers/controlling_generation.mdx b/docs/source/en/using-diffusers/controlling_generation.mdx index 57b5640ffcd5..2660903517eb 100644 --- a/docs/source/en/using-diffusers/controlling_generation.mdx +++ b/docs/source/en/using-diffusers/controlling_generation.mdx @@ -59,6 +59,7 @@ For convenience, we provide a table to denote which methods are inference-only a | [Custom Diffusion](#custom-diffusion) | ❌ | ✅ | | | [Model Editing](#model-editing) | ✅ | ❌ | | | [DiffEdit](#diffedit) | ✅ | ❌ | | +| [T2I-Adapter](#t2i-adapter) | ✅ | ❌ | | ## Instruct Pix2Pix @@ -215,4 +216,13 @@ To know more details, check out the [official doc](../api/pipelines/stable_diffu [DiffEdit](../api/pipelines/stable_diffusion/diffedit) allows for semantic editing of input images along with input prompts while preserving the original input images as much as possible. -To know more details, check out the [official doc](../api/pipelines/stable_diffusion/model_editing). \ No newline at end of file +To know more details, check out the [official doc](../api/pipelines/stable_diffusion/model_editing). +## T2I-Adapter + +[Paper](https://arxiv.org/abs/2302.08453) + +[T2I-Adapter](../api/pipelines/stable_diffusion/adapter) is an auxiliary network which adds an extra condition. +There are 8 canonical pre-trained adapters trained on different conditionings such as edge detection, sketch, +depth maps, and semantic segmentations. + +See [here](../api/pipelines/stable_diffusion/adapter) for more information on how to use it. diff --git a/scripts/convert_original_t2i_adapter.py b/scripts/convert_original_t2i_adapter.py new file mode 100644 index 000000000000..01a1fecf4e4b --- /dev/null +++ b/scripts/convert_original_t2i_adapter.py @@ -0,0 +1,250 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Conversion script for the T2I-Adapter checkpoints. +""" + +import argparse + +import torch + +from diffusers import T2IAdapter + + +def convert_adapter(src_state, in_channels): + original_body_length = max([int(x.split(".")[1]) for x in src_state.keys() if "body." in x]) + 1 + + assert original_body_length == 8 + + # (0, 1) -> channels 1 + assert src_state["body.0.block1.weight"].shape == (320, 320, 3, 3) + + # (2, 3) -> channels 2 + assert src_state["body.2.in_conv.weight"].shape == (640, 320, 1, 1) + + # (4, 5) -> channels 3 + assert src_state["body.4.in_conv.weight"].shape == (1280, 640, 1, 1) + + # (6, 7) -> channels 4 + assert src_state["body.6.block1.weight"].shape == (1280, 1280, 3, 3) + + res_state = { + "adapter.conv_in.weight": src_state.pop("conv_in.weight"), + "adapter.conv_in.bias": src_state.pop("conv_in.bias"), + # 0.resnets.0 + "adapter.body.0.resnets.0.block1.weight": src_state.pop("body.0.block1.weight"), + "adapter.body.0.resnets.0.block1.bias": src_state.pop("body.0.block1.bias"), + "adapter.body.0.resnets.0.block2.weight": src_state.pop("body.0.block2.weight"), + "adapter.body.0.resnets.0.block2.bias": src_state.pop("body.0.block2.bias"), + # 0.resnets.1 + "adapter.body.0.resnets.1.block1.weight": src_state.pop("body.1.block1.weight"), + "adapter.body.0.resnets.1.block1.bias": src_state.pop("body.1.block1.bias"), + "adapter.body.0.resnets.1.block2.weight": src_state.pop("body.1.block2.weight"), + "adapter.body.0.resnets.1.block2.bias": src_state.pop("body.1.block2.bias"), + # 1 + "adapter.body.1.in_conv.weight": src_state.pop("body.2.in_conv.weight"), + "adapter.body.1.in_conv.bias": src_state.pop("body.2.in_conv.bias"), + # 1.resnets.0 + "adapter.body.1.resnets.0.block1.weight": src_state.pop("body.2.block1.weight"), + "adapter.body.1.resnets.0.block1.bias": src_state.pop("body.2.block1.bias"), + "adapter.body.1.resnets.0.block2.weight": src_state.pop("body.2.block2.weight"), + "adapter.body.1.resnets.0.block2.bias": src_state.pop("body.2.block2.bias"), + # 1.resnets.1 + "adapter.body.1.resnets.1.block1.weight": src_state.pop("body.3.block1.weight"), + "adapter.body.1.resnets.1.block1.bias": src_state.pop("body.3.block1.bias"), + "adapter.body.1.resnets.1.block2.weight": src_state.pop("body.3.block2.weight"), + "adapter.body.1.resnets.1.block2.bias": src_state.pop("body.3.block2.bias"), + # 2 + "adapter.body.2.in_conv.weight": src_state.pop("body.4.in_conv.weight"), + "adapter.body.2.in_conv.bias": src_state.pop("body.4.in_conv.bias"), + # 2.resnets.0 + "adapter.body.2.resnets.0.block1.weight": src_state.pop("body.4.block1.weight"), + "adapter.body.2.resnets.0.block1.bias": src_state.pop("body.4.block1.bias"), + "adapter.body.2.resnets.0.block2.weight": src_state.pop("body.4.block2.weight"), + "adapter.body.2.resnets.0.block2.bias": src_state.pop("body.4.block2.bias"), + # 2.resnets.1 + "adapter.body.2.resnets.1.block1.weight": src_state.pop("body.5.block1.weight"), + "adapter.body.2.resnets.1.block1.bias": src_state.pop("body.5.block1.bias"), + "adapter.body.2.resnets.1.block2.weight": src_state.pop("body.5.block2.weight"), + "adapter.body.2.resnets.1.block2.bias": src_state.pop("body.5.block2.bias"), + # 3.resnets.0 + "adapter.body.3.resnets.0.block1.weight": src_state.pop("body.6.block1.weight"), + "adapter.body.3.resnets.0.block1.bias": src_state.pop("body.6.block1.bias"), + "adapter.body.3.resnets.0.block2.weight": src_state.pop("body.6.block2.weight"), + "adapter.body.3.resnets.0.block2.bias": src_state.pop("body.6.block2.bias"), + # 3.resnets.1 + "adapter.body.3.resnets.1.block1.weight": src_state.pop("body.7.block1.weight"), + "adapter.body.3.resnets.1.block1.bias": src_state.pop("body.7.block1.bias"), + "adapter.body.3.resnets.1.block2.weight": src_state.pop("body.7.block2.weight"), + "adapter.body.3.resnets.1.block2.bias": src_state.pop("body.7.block2.bias"), + } + + assert len(src_state) == 0 + + adapter = T2IAdapter(in_channels=in_channels, adapter_type="full_adapter") + + adapter.load_state_dict(res_state) + + return adapter + + +def convert_light_adapter(src_state): + original_body_length = max([int(x.split(".")[1]) for x in src_state.keys() if "body." in x]) + 1 + + assert original_body_length == 4 + + res_state = { + # body.0.in_conv + "adapter.body.0.in_conv.weight": src_state.pop("body.0.in_conv.weight"), + "adapter.body.0.in_conv.bias": src_state.pop("body.0.in_conv.bias"), + # body.0.resnets.0 + "adapter.body.0.resnets.0.block1.weight": src_state.pop("body.0.body.0.block1.weight"), + "adapter.body.0.resnets.0.block1.bias": src_state.pop("body.0.body.0.block1.bias"), + "adapter.body.0.resnets.0.block2.weight": src_state.pop("body.0.body.0.block2.weight"), + "adapter.body.0.resnets.0.block2.bias": src_state.pop("body.0.body.0.block2.bias"), + # body.0.resnets.1 + "adapter.body.0.resnets.1.block1.weight": src_state.pop("body.0.body.1.block1.weight"), + "adapter.body.0.resnets.1.block1.bias": src_state.pop("body.0.body.1.block1.bias"), + "adapter.body.0.resnets.1.block2.weight": src_state.pop("body.0.body.1.block2.weight"), + "adapter.body.0.resnets.1.block2.bias": src_state.pop("body.0.body.1.block2.bias"), + # body.0.resnets.2 + "adapter.body.0.resnets.2.block1.weight": src_state.pop("body.0.body.2.block1.weight"), + "adapter.body.0.resnets.2.block1.bias": src_state.pop("body.0.body.2.block1.bias"), + "adapter.body.0.resnets.2.block2.weight": src_state.pop("body.0.body.2.block2.weight"), + "adapter.body.0.resnets.2.block2.bias": src_state.pop("body.0.body.2.block2.bias"), + # body.0.resnets.3 + "adapter.body.0.resnets.3.block1.weight": src_state.pop("body.0.body.3.block1.weight"), + "adapter.body.0.resnets.3.block1.bias": src_state.pop("body.0.body.3.block1.bias"), + "adapter.body.0.resnets.3.block2.weight": src_state.pop("body.0.body.3.block2.weight"), + "adapter.body.0.resnets.3.block2.bias": src_state.pop("body.0.body.3.block2.bias"), + # body.0.out_conv + "adapter.body.0.out_conv.weight": src_state.pop("body.0.out_conv.weight"), + "adapter.body.0.out_conv.bias": src_state.pop("body.0.out_conv.bias"), + # body.1.in_conv + "adapter.body.1.in_conv.weight": src_state.pop("body.1.in_conv.weight"), + "adapter.body.1.in_conv.bias": src_state.pop("body.1.in_conv.bias"), + # body.1.resnets.0 + "adapter.body.1.resnets.0.block1.weight": src_state.pop("body.1.body.0.block1.weight"), + "adapter.body.1.resnets.0.block1.bias": src_state.pop("body.1.body.0.block1.bias"), + "adapter.body.1.resnets.0.block2.weight": src_state.pop("body.1.body.0.block2.weight"), + "adapter.body.1.resnets.0.block2.bias": src_state.pop("body.1.body.0.block2.bias"), + # body.1.resnets.1 + "adapter.body.1.resnets.1.block1.weight": src_state.pop("body.1.body.1.block1.weight"), + "adapter.body.1.resnets.1.block1.bias": src_state.pop("body.1.body.1.block1.bias"), + "adapter.body.1.resnets.1.block2.weight": src_state.pop("body.1.body.1.block2.weight"), + "adapter.body.1.resnets.1.block2.bias": src_state.pop("body.1.body.1.block2.bias"), + # body.1.body.2 + "adapter.body.1.resnets.2.block1.weight": src_state.pop("body.1.body.2.block1.weight"), + "adapter.body.1.resnets.2.block1.bias": src_state.pop("body.1.body.2.block1.bias"), + "adapter.body.1.resnets.2.block2.weight": src_state.pop("body.1.body.2.block2.weight"), + "adapter.body.1.resnets.2.block2.bias": src_state.pop("body.1.body.2.block2.bias"), + # body.1.body.3 + "adapter.body.1.resnets.3.block1.weight": src_state.pop("body.1.body.3.block1.weight"), + "adapter.body.1.resnets.3.block1.bias": src_state.pop("body.1.body.3.block1.bias"), + "adapter.body.1.resnets.3.block2.weight": src_state.pop("body.1.body.3.block2.weight"), + "adapter.body.1.resnets.3.block2.bias": src_state.pop("body.1.body.3.block2.bias"), + # body.1.out_conv + "adapter.body.1.out_conv.weight": src_state.pop("body.1.out_conv.weight"), + "adapter.body.1.out_conv.bias": src_state.pop("body.1.out_conv.bias"), + # body.2.in_conv + "adapter.body.2.in_conv.weight": src_state.pop("body.2.in_conv.weight"), + "adapter.body.2.in_conv.bias": src_state.pop("body.2.in_conv.bias"), + # body.2.body.0 + "adapter.body.2.resnets.0.block1.weight": src_state.pop("body.2.body.0.block1.weight"), + "adapter.body.2.resnets.0.block1.bias": src_state.pop("body.2.body.0.block1.bias"), + "adapter.body.2.resnets.0.block2.weight": src_state.pop("body.2.body.0.block2.weight"), + "adapter.body.2.resnets.0.block2.bias": src_state.pop("body.2.body.0.block2.bias"), + # body.2.body.1 + "adapter.body.2.resnets.1.block1.weight": src_state.pop("body.2.body.1.block1.weight"), + "adapter.body.2.resnets.1.block1.bias": src_state.pop("body.2.body.1.block1.bias"), + "adapter.body.2.resnets.1.block2.weight": src_state.pop("body.2.body.1.block2.weight"), + "adapter.body.2.resnets.1.block2.bias": src_state.pop("body.2.body.1.block2.bias"), + # body.2.body.2 + "adapter.body.2.resnets.2.block1.weight": src_state.pop("body.2.body.2.block1.weight"), + "adapter.body.2.resnets.2.block1.bias": src_state.pop("body.2.body.2.block1.bias"), + "adapter.body.2.resnets.2.block2.weight": src_state.pop("body.2.body.2.block2.weight"), + "adapter.body.2.resnets.2.block2.bias": src_state.pop("body.2.body.2.block2.bias"), + # body.2.body.3 + "adapter.body.2.resnets.3.block1.weight": src_state.pop("body.2.body.3.block1.weight"), + "adapter.body.2.resnets.3.block1.bias": src_state.pop("body.2.body.3.block1.bias"), + "adapter.body.2.resnets.3.block2.weight": src_state.pop("body.2.body.3.block2.weight"), + "adapter.body.2.resnets.3.block2.bias": src_state.pop("body.2.body.3.block2.bias"), + # body.2.out_conv + "adapter.body.2.out_conv.weight": src_state.pop("body.2.out_conv.weight"), + "adapter.body.2.out_conv.bias": src_state.pop("body.2.out_conv.bias"), + # body.3.in_conv + "adapter.body.3.in_conv.weight": src_state.pop("body.3.in_conv.weight"), + "adapter.body.3.in_conv.bias": src_state.pop("body.3.in_conv.bias"), + # body.3.body.0 + "adapter.body.3.resnets.0.block1.weight": src_state.pop("body.3.body.0.block1.weight"), + "adapter.body.3.resnets.0.block1.bias": src_state.pop("body.3.body.0.block1.bias"), + "adapter.body.3.resnets.0.block2.weight": src_state.pop("body.3.body.0.block2.weight"), + "adapter.body.3.resnets.0.block2.bias": src_state.pop("body.3.body.0.block2.bias"), + # body.3.body.1 + "adapter.body.3.resnets.1.block1.weight": src_state.pop("body.3.body.1.block1.weight"), + "adapter.body.3.resnets.1.block1.bias": src_state.pop("body.3.body.1.block1.bias"), + "adapter.body.3.resnets.1.block2.weight": src_state.pop("body.3.body.1.block2.weight"), + "adapter.body.3.resnets.1.block2.bias": src_state.pop("body.3.body.1.block2.bias"), + # body.3.body.2 + "adapter.body.3.resnets.2.block1.weight": src_state.pop("body.3.body.2.block1.weight"), + "adapter.body.3.resnets.2.block1.bias": src_state.pop("body.3.body.2.block1.bias"), + "adapter.body.3.resnets.2.block2.weight": src_state.pop("body.3.body.2.block2.weight"), + "adapter.body.3.resnets.2.block2.bias": src_state.pop("body.3.body.2.block2.bias"), + # body.3.body.3 + "adapter.body.3.resnets.3.block1.weight": src_state.pop("body.3.body.3.block1.weight"), + "adapter.body.3.resnets.3.block1.bias": src_state.pop("body.3.body.3.block1.bias"), + "adapter.body.3.resnets.3.block2.weight": src_state.pop("body.3.body.3.block2.weight"), + "adapter.body.3.resnets.3.block2.bias": src_state.pop("body.3.body.3.block2.bias"), + # body.3.out_conv + "adapter.body.3.out_conv.weight": src_state.pop("body.3.out_conv.weight"), + "adapter.body.3.out_conv.bias": src_state.pop("body.3.out_conv.bias"), + } + + assert len(src_state) == 0 + + adapter = T2IAdapter(in_channels=3, channels=[320, 640, 1280], num_res_blocks=4, adapter_type="light_adapter") + + adapter.load_state_dict(res_state) + + return adapter + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--output_path", default=None, type=str, required=True, help="Path to the store the result checkpoint." + ) + parser.add_argument( + "--is_adapter_light", + action="store_true", + help="Is checkpoint come from Adapter-Light architecture. ex: color-adapter", + ) + parser.add_argument("--in_channels", required=False, type=int, help="Input channels for non-light adapter") + + args = parser.parse_args() + src_state = torch.load(args.checkpoint_path) + + if args.is_adapter_light: + adapter = convert_light_adapter(src_state) + else: + if args.in_channels is None: + raise ValueError("set `--in_channels=`") + adapter = convert_adapter(src_state, args.in_channels) + + adapter.save_pretrained(args.output_path) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 39178edc00d1..0c0869fb52bd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -39,7 +39,9 @@ AutoencoderKL, ControlNetModel, ModelMixin, + MultiAdapter, PriorTransformer, + T2IAdapter, T5FilmDecoder, Transformer2DModel, UNet1DModel, @@ -151,6 +153,7 @@ SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, + StableDiffusionAdapterPipeline, StableDiffusionAttendAndExcitePipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 23839c84af45..6e330a44691a 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,6 +16,7 @@ if is_torch_available(): + from .adapter import MultiAdapter, T2IAdapter from .autoencoder_kl import AutoencoderKL from .controlnet import ControlNetModel from .dual_transformer_2d import DualTransformer2DModel diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py new file mode 100644 index 000000000000..a65a3873b130 --- /dev/null +++ b/src/diffusers/models/adapter.py @@ -0,0 +1,291 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import torch +import torch.nn as nn + +from ..configuration_utils import ConfigMixin, register_to_config +from .modeling_utils import ModelMixin +from .resnet import Downsample2D + + +class MultiAdapter(ModelMixin): + r""" + MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to + user-assigned weighting. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + adapters (`List[T2IAdapter]`, *optional*, defaults to None): + A list of `T2IAdapter` model instances. + """ + + def __init__(self, adapters: List["T2IAdapter"]): + super(MultiAdapter, self).__init__() + + self.num_adapter = len(adapters) + self.adapters = nn.ModuleList(adapters) + + def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]: + r""" + Args: + xs (`torch.Tensor`): + (batch, channel, height, width) input images for multiple adapter models concated along dimension 1, + `channel` should equal to `num_adapter` * "number of channel of image". + adapter_weights (`List[float]`, *optional*, defaults to None): + List of floats representing the weight which will be multiply to each adapter's output before adding + them together. + """ + if adapter_weights is None: + adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter) + else: + adapter_weights = torch.tensor(adapter_weights) + + if xs.shape[1] % self.num_adapter != 0: + raise ValueError( + f"Expecting multi-adapter's input have number of channel that cab be evenly divisible " + f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0" + ) + x_list = torch.chunk(xs, self.num_adapter, dim=1) + accume_state = None + for x, w, adapter in zip(x_list, adapter_weights, self.adapters): + features = adapter(x) + if accume_state is None: + accume_state = features + else: + for i in range(len(features)): + accume_state[i] += w * features[i] + return accume_state + + +class T2IAdapter(ModelMixin, ConfigMixin): + r""" + A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model + generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's + architecture follows the original implementation of + [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97) + and + [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (`int`, *optional*, defaults to 3): + Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale + image as *control image*. + channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will + also determine the number of downsample blocks in the Adapter. + num_res_blocks (`int`, *optional*, defaults to 2): + Number of ResNet blocks in each downsample block + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + channels: List[int] = [320, 640, 1280, 1280], + num_res_blocks: int = 2, + downscale_factor: int = 8, + adapter_type: str = "full_adapter", + ): + super().__init__() + + if adapter_type == "full_adapter": + self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor) + elif adapter_type == "light_adapter": + self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor) + else: + raise ValueError(f"unknown adapter_type: {type}. Choose either 'full_adapter' or 'simple_adapter'") + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + return self.adapter(x) + + @property + def total_downscale_factor(self): + return self.adapter.total_downscale_factor + + +# full adapter + + +class FullAdapter(nn.Module): + def __init__( + self, + in_channels: int = 3, + channels: List[int] = [320, 640, 1280, 1280], + num_res_blocks: int = 2, + downscale_factor: int = 8, + ): + super().__init__() + + in_channels = in_channels * downscale_factor**2 + + self.unshuffle = nn.PixelUnshuffle(downscale_factor) + self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1) + + self.body = nn.ModuleList( + [ + AdapterBlock(channels[0], channels[0], num_res_blocks), + *[ + AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True) + for i in range(1, len(channels)) + ], + ] + ) + + self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.unshuffle(x) + x = self.conv_in(x) + + features = [] + + for block in self.body: + x = block(x) + features.append(x) + + return features + + +class AdapterBlock(nn.Module): + def __init__(self, in_channels, out_channels, num_res_blocks, down=False): + super().__init__() + + self.downsample = None + if down: + self.downsample = Downsample2D(in_channels) + + self.in_conv = None + if in_channels != out_channels: + self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + self.resnets = nn.Sequential( + *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)], + ) + + def forward(self, x): + if self.downsample is not None: + x = self.downsample(x) + + if self.in_conv is not None: + x = self.in_conv(x) + + x = self.resnets(x) + + return x + + +class AdapterResnetBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + self.act = nn.ReLU() + self.block2 = nn.Conv2d(channels, channels, kernel_size=1) + + def forward(self, x): + h = x + h = self.block1(h) + h = self.act(h) + h = self.block2(h) + + return h + x + + +# light adapter + + +class LightAdapter(nn.Module): + def __init__( + self, + in_channels: int = 3, + channels: List[int] = [320, 640, 1280], + num_res_blocks: int = 4, + downscale_factor: int = 8, + ): + super().__init__() + + in_channels = in_channels * downscale_factor**2 + + self.unshuffle = nn.PixelUnshuffle(downscale_factor) + + self.body = nn.ModuleList( + [ + LightAdapterBlock(in_channels, channels[0], num_res_blocks), + *[ + LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True) + for i in range(len(channels) - 1) + ], + LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True), + ] + ) + + self.total_downscale_factor = downscale_factor * (2 ** len(channels)) + + def forward(self, x): + x = self.unshuffle(x) + + features = [] + + for block in self.body: + x = block(x) + features.append(x) + + return features + + +class LightAdapterBlock(nn.Module): + def __init__(self, in_channels, out_channels, num_res_blocks, down=False): + super().__init__() + mid_channels = out_channels // 4 + + self.downsample = None + if down: + self.downsample = Downsample2D(in_channels) + + self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1) + self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)]) + self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1) + + def forward(self, x): + if self.downsample is not None: + x = self.downsample(x) + + x = self.in_conv(x) + x = self.resnets(x) + x = self.out_conv(x) + + return x + + +class LightAdapterResnetBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + self.act = nn.ReLU() + self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + + def forward(self, x): + h = x + h = self.block1(h) + h = self.act(h) + h = self.block2(h) + + return h + x diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index cb3452f4459c..9bdcf8fa8d9c 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -955,6 +955,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals=None, ): output_states = () @@ -1001,6 +1002,10 @@ def custom_forward(*inputs): output_states = output_states + (hidden_states,) + if additional_residuals is not None: + hidden_states += additional_residuals + output_states = output_states[:-1] + (output_states[-1] + additional_residuals,) + if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index dee71bead0f9..516838b91e1a 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -899,9 +899,17 @@ def forward( sample = self.conv_in(sample) # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + additional_kwargs = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_kwargs["additional_residuals"] = down_block_additional_residuals.pop(0) + sample, res_samples = downsample_block( hidden_states=sample, temb=emb, @@ -909,13 +917,17 @@ def forward( attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, + **additional_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + down_block_res_samples += res_samples - if down_block_additional_residuals is not None: + if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( @@ -937,8 +949,8 @@ def forward( encoder_attention_mask=encoder_attention_mask, ) - if mid_block_additional_residual is not None: - sample = sample + mid_block_additional_residual + if is_controlnet: + sample += mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 937ac1b5e3d7..2f1a9b6b93cb 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -80,6 +80,7 @@ from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_diffusion import ( CycleDiffusionPipeline, + StableDiffusionAdapterPipeline, StableDiffusionAttendAndExcitePipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 33ab05a1dacb..6edfe5a6da4d 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -44,6 +44,7 @@ class StableDiffusionPipelineOutput(BaseOutput): else: from .pipeline_cycle_diffusion import CycleDiffusionPipeline from .pipeline_stable_diffusion import StableDiffusionPipeline + from .pipeline_stable_diffusion_adapter import StableDiffusionAdapterPipeline from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py new file mode 100644 index 000000000000..2bf2ac5c029d --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py @@ -0,0 +1,799 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from PIL import Image + >>> from diffusers.utils import load_image + + >>> image = load_image("https://huggingface.co/RzZ/sd-v1-4-adapter-color/resolve/main/color_ref.png") + + >>> color_palette = image.resize((8, 8)) + >>> color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST) + + >>> import torch + >>> from diffusers import StableDiffusionAdapterPipeline, T2IAdapter + + >>> adapter = T2IAdapter.from_pretrained("RzZ/sd-v1-4-adapter-color") + >>> pipe = StableDiffusionAdapterPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", + ... adapter=adapter, + ... torch_dtype=torch.float16, + ... ) + + >>> pipe.to("cuda") + + >>> out_image = pipe( + ... "At night, glowing cubes in front of the beach", + ... image=color_palette, + ... generator=generator, + ... ).images[0] + ``` +""" + + +def preprocess(image, height, width): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image] + image = [ + i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image + ] # expand [h, w] or [h, w, c] to [b, h, w, c] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + if image[0].ndim == 3: + image = torch.stack(image, dim=0) + elif image[0].ndim == 4: + image = torch.cat(image, dim=0) + else: + raise ValueError( + f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}" + ) + return image + + +class StableDiffusionAdapterPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter + https://arxiv.org/abs/2302.08453 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a + list, the outputs from each Adapter are added together to create one combined additional conditioning. + adapter_weights (`List[float]`, *optional*, defaults to None): + List of floats representing the weight which will be multiply to each adapter's output before adding them + together. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + adapter_weights: Optional[List[float]] = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(adapter, (list, tuple)): + adapter = MultiAdapter(adapter, adapter_weights=adapter_weights) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + adapter=adapter, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.adapter]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.adapter, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[-2] + + # round down to nearest multiple of `self.adapter.total_downscale_factor` + height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[-1] + + # round down to nearest multiple of `self.adapter.total_downscale_factor` + width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor + + return height, width + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + adapter_conditioning_scale: Union[float, List[float]] = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`): + The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the + type is specified as `Torch.FloatTensor`, it is passed to Adapter as is. PIL.Image.Image` can also be + accepted as an image. The control image is automatically resized to fit the output image. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the + residual in the original unet. If multiple adapters are specified in init, you can set the + corresponding scale as a list. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) + device = self._execution_device + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + is_multi_adapter = isinstance(self.adapter, MultiAdapter) + if is_multi_adapter: + adapter_input = [preprocess(img, height, width).to(device) for img in image] + n, c, h, w = adapter_input[0].shape + adapter_input = torch.stack([x.reshape([n * c, h, w]) for x in adapter_input]) + else: + adapter_input = preprocess(image, height, width).to(device) + adapter_input = adapter_input.to(self.adapter.dtype) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + adapter_state = self.adapter(adapter_input) + for k, v in enumerate(adapter_state): + adapter_state[k] = v * adapter_conditioning_scale + if num_images_per_prompt > 1: + for k, v in enumerate(adapter_state): + adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) + if do_classifier_free_guidance: + for k, v in enumerate(adapter_state): + adapter_state[k] = torch.cat([v] * 2, dim=0) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=[state.clone() for state in adapter_state], + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 0a2fad6aee1a..8f8ea6c88414 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1012,9 +1012,17 @@ def forward( sample = self.conv_in(sample) # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + additional_kwargs = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_kwargs["additional_residuals"] = down_block_additional_residuals.pop(0) + sample, res_samples = downsample_block( hidden_states=sample, temb=emb, @@ -1022,13 +1030,17 @@ def forward( attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, + **additional_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + down_block_res_samples += res_samples - if down_block_additional_residuals is not None: + if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( @@ -1050,8 +1062,8 @@ def forward( encoder_attention_mask=encoder_attention_mask, ) - if mid_block_additional_residual is not None: - sample = sample + mid_block_additional_residual + if is_controlnet: + sample += mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): @@ -1390,6 +1402,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals=None, ): output_states = () @@ -1436,6 +1449,10 @@ def custom_forward(*inputs): output_states = output_states + (hidden_states,) + if additional_residuals is not None: + hidden_states += additional_residuals + output_states = output_states[:-1] + (output_states[-1] + additional_residuals,) + if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 20dbf84681d3..b955ec5320de 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -47,6 +47,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class MultiAdapter(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PriorTransformer(metaclass=DummyObject): _backends = ["torch"] @@ -62,6 +77,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class T2IAdapter(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class T5FilmDecoder(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 164206d776fa..016760337c69 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -407,6 +407,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionAdapterPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py new file mode 100644 index 000000000000..f40ef36b6e2e --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py @@ -0,0 +1,316 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + PNDMScheduler, + StableDiffusionAdapterPipeline, + T2IAdapter, + UNet2DConditionModel, +) +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu + +from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class AdapterTests: + pipeline_class = StableDiffusionAdapterPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + + def get_dummy_components(self, adapter_type): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = PNDMScheduler(skip_prk_steps=True) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + torch.manual_seed(0) + adapter = T2IAdapter( + in_channels=3, + channels=[32, 64], + num_res_blocks=2, + downscale_factor=2, + adapter_type=adapter_type, + ) + + components = { + "adapter": adapter, + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + } + return inputs + + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + +class StableDiffusionFullAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase): + def get_dummy_components(self): + return super().get_dummy_components("full_adapter") + + def test_stable_diffusion_adapter_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionAdapterPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4858, 0.5500, 0.4278, 0.4669, 0.6184, 0.4322, 0.5010, 0.5033, 0.4746]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3 + + +class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase): + def get_dummy_components(self): + return super().get_dummy_components("light_adapter") + + def test_stable_diffusion_adapter_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionAdapterPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4965, 0.5548, 0.4330, 0.4771, 0.6226, 0.4382, 0.5037, 0.5071, 0.4782]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3 + + +@slow +@require_torch_gpu +class StableDiffusionAdapterPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_adapter(self): + test_cases = [ + ( + "diffusers/t2iadapter_color_sd14v1", + "CompVis/stable-diffusion-v1-4", + "snail", + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/color.png", + 3, + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_color_sd14v1.npy", + ), + ( + "diffusers/t2iadapter_depth_sd14v1", + "CompVis/stable-diffusion-v1-4", + "desk", + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/desk_depth.png", + 3, + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_depth_sd14v1.npy", + ), + ( + "diffusers/t2iadapter_depth_sd15v2", + "runwayml/stable-diffusion-v1-5", + "desk", + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/desk_depth.png", + 3, + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_depth_sd15v2.npy", + ), + ( + "diffusers/t2iadapter_keypose_sd14v1", + "CompVis/stable-diffusion-v1-4", + "person", + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/person_keypose.png", + 3, + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_keypose_sd14v1.npy", + ), + ( + "diffusers/t2iadapter_openpose_sd14v1", + "CompVis/stable-diffusion-v1-4", + "person", + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/iron_man_pose.png", + 3, + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_openpose_sd14v1.npy", + ), + ( + "diffusers/t2iadapter_seg_sd14v1", + "CompVis/stable-diffusion-v1-4", + "motorcycle", + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motor.png", + 3, + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_seg_sd14v1.npy", + ), + ( + "diffusers/t2iadapter_zoedepth_sd15v1", + "runwayml/stable-diffusion-v1-5", + "motorcycle", + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motorcycle.png", + 3, + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_zoedepth_sd15v1.npy", + ), + ( + "diffusers/t2iadapter_canny_sd14v1", + "CompVis/stable-diffusion-v1-4", + "toy", + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png", + 1, + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_canny_sd14v1.npy", + ), + ( + "diffusers/t2iadapter_canny_sd15v2", + "runwayml/stable-diffusion-v1-5", + "toy", + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png", + 1, + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_canny_sd15v2.npy", + ), + ( + "diffusers/t2iadapter_sketch_sd14v1", + "CompVis/stable-diffusion-v1-4", + "cat", + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/edge.png", + 1, + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_sketch_sd14v1.npy", + ), + ( + "diffusers/t2iadapter_sketch_sd15v2", + "runwayml/stable-diffusion-v1-5", + "cat", + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/edge.png", + 1, + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_sketch_sd15v2.npy", + ), + ] + + for adapter_model, sd_model, prompt, image_url, input_channels, out_url in test_cases: + image = load_image(image_url) + expected_out = load_numpy(out_url) + + if input_channels == 1: + image = image.convert("L") + + adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16) + + pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + generator = torch.Generator(device="cpu").manual_seed(0) + + out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images + + self.assertTrue(np.allclose(out, expected_out)) + + def test_stable_diffusion_adapter_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + adapter = T2IAdapter.from_pretrained("diffusers/t2iadapter_seg_sd14v1") + pipe = StableDiffusionAdapterPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", adapter=adapter, safety_checker=None + ) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motor.png" + ) + + pipe(prompt="foo", image=image, num_inference_steps=2) + + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes < 5 * 10**9 From feafd19876d77c408b33959fed0304c357b0cc99 Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 11 Jul 2023 10:57:57 -0700 Subject: [PATCH 02/10] fix sample inplace add --- src/diffusers/models/unet_2d_condition.py | 2 +- .../pipelines/versatile_diffusion/modeling_text_unet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 516838b91e1a..d307209cf873 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -950,7 +950,7 @@ def forward( ) if is_controlnet: - sample += mid_block_additional_residual + sample = sample + mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 8f8ea6c88414..32ad92f436eb 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1063,7 +1063,7 @@ def forward( ) if is_controlnet: - sample += mid_block_additional_residual + sample = sample + mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): From b75b507830c4bb964ccadef22665b84e3e238abb Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 11 Jul 2023 11:02:39 -0700 Subject: [PATCH 03/10] additional_kwargs -> additional_residuals --- src/diffusers/models/unet_2d_condition.py | 7 ++++--- .../pipelines/versatile_diffusion/modeling_text_unet.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index d307209cf873..d7756ab5edb3 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -906,9 +906,10 @@ def forward( down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - additional_kwargs = {} + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} if is_adapter and len(down_block_additional_residuals) > 0: - additional_kwargs["additional_residuals"] = down_block_additional_residuals.pop(0) + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, @@ -917,7 +918,7 @@ def forward( attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, - **additional_kwargs, + **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 32ad92f436eb..5a20834f9802 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1019,9 +1019,10 @@ def forward( down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - additional_kwargs = {} + # For t2i-adapter CrossAttnDownBlockFlat + additional_residuals = {} if is_adapter and len(down_block_additional_residuals) > 0: - additional_kwargs["additional_residuals"] = down_block_additional_residuals.pop(0) + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, @@ -1030,7 +1031,7 @@ def forward( attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, - **additional_kwargs, + **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) From b0f356c5a576a663d8ce175ef63c7c9e987fcbe0 Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 11 Jul 2023 11:10:08 -0700 Subject: [PATCH 04/10] move t2i adapter pipeline to own module --- src/diffusers/pipelines/__init__.py | 2 +- .../pipelines/stable_diffusion/__init__.py | 1 - .../pipelines/t2i_adapter/__init__.py | 14 +++++++ .../pipeline_stable_diffusion_adapter.py | 37 ++++++++++++++----- 4 files changed, 42 insertions(+), 12 deletions(-) create mode 100644 src/diffusers/pipelines/t2i_adapter/__init__.py rename src/diffusers/pipelines/{stable_diffusion => t2i_adapter}/pipeline_stable_diffusion_adapter.py (96%) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 2f1a9b6b93cb..aa09e7e81130 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -80,7 +80,6 @@ from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_diffusion import ( CycleDiffusionPipeline, - StableDiffusionAdapterPipeline, StableDiffusionAttendAndExcitePipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, @@ -102,6 +101,7 @@ StableUnCLIPPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .t2i_adapter import StableDiffusionAdapterPipeline from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 6edfe5a6da4d..33ab05a1dacb 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -44,7 +44,6 @@ class StableDiffusionPipelineOutput(BaseOutput): else: from .pipeline_cycle_diffusion import CycleDiffusionPipeline from .pipeline_stable_diffusion import StableDiffusionPipeline - from .pipeline_stable_diffusion_adapter import StableDiffusionAdapterPipeline from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline diff --git a/src/diffusers/pipelines/t2i_adapter/__init__.py b/src/diffusers/pipelines/t2i_adapter/__init__.py new file mode 100644 index 000000000000..c4de661dbefa --- /dev/null +++ b/src/diffusers/pipelines/t2i_adapter/__init__.py @@ -0,0 +1,14 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, +) + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .pipeline_stable_diffusion_adapter import StableDiffusionAdapterPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py similarity index 96% rename from src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py rename to src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 2bf2ac5c029d..32c6aca5f94d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -14,6 +14,7 @@ import inspect import warnings +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union import numpy as np @@ -26,6 +27,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, + BaseOutput, is_accelerate_available, is_accelerate_version, logging, @@ -33,8 +35,23 @@ replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +@dataclass +class StableDiffusionAdapterPipelineOutput(BaseOutput): + """ + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -641,8 +658,8 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] instead + of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. @@ -661,11 +678,11 @@ def __call__( Examples: Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height, width = self._default_height_width(height, width, image) @@ -796,4 +813,4 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionAdapterPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From bd2201fa593e2cf88f1f21432268a2fe0a62a422 Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 11 Jul 2023 11:51:38 -0700 Subject: [PATCH 05/10] preprocess -> _preprocess_adapter_image --- .../t2i_adapter/pipeline_stable_diffusion_adapter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 32c6aca5f94d..5e77dafce63d 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -88,7 +88,7 @@ class StableDiffusionAdapterPipelineOutput(BaseOutput): """ -def preprocess(image, height, width): +def _preprocess_adapter_image(image, height, width): if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -695,11 +695,11 @@ def __call__( is_multi_adapter = isinstance(self.adapter, MultiAdapter) if is_multi_adapter: - adapter_input = [preprocess(img, height, width).to(device) for img in image] + adapter_input = [_preprocess_adapter_image(img, height, width).to(device) for img in image] n, c, h, w = adapter_input[0].shape adapter_input = torch.stack([x.reshape([n * c, h, w]) for x in adapter_input]) else: - adapter_input = preprocess(image, height, width).to(device) + adapter_input = _preprocess_adapter_image(image, height, width).to(device) adapter_input = adapter_input.to(self.adapter.dtype) # 2. Define call parameters From d66631a35f446dce9e915d123a2e3835561e502c Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 11 Jul 2023 11:54:22 -0700 Subject: [PATCH 06/10] add TencentArc to license --- .../pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 5e77dafce63d..a0ed16b82b1b 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2023 TencentARC and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From f7069180b9ced38a2cb62f04f9a7c8f02bee9c98 Mon Sep 17 00:00:00 2001 From: William Berman Date: Mon, 17 Jul 2023 11:23:02 -0700 Subject: [PATCH 07/10] fix example code links --- docs/source/en/_toctree.yml | 4 ++-- .../t2i_adapter/pipeline_stable_diffusion_adapter.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f477f288bbe0..b60933c05228 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -196,12 +196,12 @@ title: DDIM - local: api/pipelines/ddpm title: DDPM + - local: api/pipelines/deepfloyd_if + title: DeepFloyd IF - local: api/pipelines/diffedit title: DiffEdit - local: api/pipelines/dit title: DiT - - local: api/pipelines/deepfloyd_if - title: DeepFloyd IF - local: api/pipelines/pix2pix title: InstructPix2Pix - local: api/pipelines/kandinsky diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index a0ed16b82b1b..629154ede0b1 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -62,7 +62,9 @@ class StableDiffusionAdapterPipelineOutput(BaseOutput): >>> from PIL import Image >>> from diffusers.utils import load_image - >>> image = load_image("https://huggingface.co/RzZ/sd-v1-4-adapter-color/resolve/main/color_ref.png") + >>> image = load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png" + ... ) >>> color_palette = image.resize((8, 8)) >>> color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST) @@ -70,7 +72,7 @@ class StableDiffusionAdapterPipelineOutput(BaseOutput): >>> import torch >>> from diffusers import StableDiffusionAdapterPipeline, T2IAdapter - >>> adapter = T2IAdapter.from_pretrained("RzZ/sd-v1-4-adapter-color") + >>> adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1") >>> pipe = StableDiffusionAdapterPipeline.from_pretrained( ... "CompVis/stable-diffusion-v1-4", ... adapter=adapter, From ec031943f76ee87e3d9e5352f91d5747157c1699 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Mon, 17 Jul 2023 18:35:09 +0000 Subject: [PATCH 08/10] add image converter and fix example doc string --- .../t2i_adapter/pipeline_stable_diffusion_adapter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 629154ede0b1..c84c99e6d19a 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -22,6 +22,7 @@ import torch from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -61,6 +62,8 @@ class StableDiffusionAdapterPipelineOutput(BaseOutput): ```py >>> from PIL import Image >>> from diffusers.utils import load_image + >>> import torch + >>> from diffusers import StableDiffusionAdapterPipeline, T2IAdapter >>> image = load_image( ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png" @@ -69,10 +72,7 @@ class StableDiffusionAdapterPipelineOutput(BaseOutput): >>> color_palette = image.resize((8, 8)) >>> color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST) - >>> import torch - >>> from diffusers import StableDiffusionAdapterPipeline, T2IAdapter - - >>> adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1") + >>> adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1", torch_dtype=torch.float16) >>> pipe = StableDiffusionAdapterPipeline.from_pretrained( ... "CompVis/stable-diffusion-v1-4", ... adapter=adapter, @@ -84,7 +84,6 @@ class StableDiffusionAdapterPipelineOutput(BaseOutput): >>> out_image = pipe( ... "At night, glowing cubes in front of the beach", ... image=color_palette, - ... generator=generator, ... ).images[0] ``` """ @@ -198,6 +197,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing From 80a14528f7b759f13ce050759c305be8a9db7338 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Mon, 17 Jul 2023 18:58:08 +0000 Subject: [PATCH 09/10] fix links --- .../pipelines/stable_diffusion/adapter.mdx | 30 +++++++++---------- .../test_stable_diffusion_adapter.py | 24 +++++++-------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/adapter.mdx b/docs/source/en/api/pipelines/stable_diffusion/adapter.mdx index 6e07e55250e6..19351e1713b6 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/adapter.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/adapter.mdx @@ -38,7 +38,7 @@ All adapters use the same pipeline. 1. Images are first converted into the appropriate *control image* format. 2. The *control image* and *prompt* are passed to the [`StableDiffusionAdapterPipeline`]. -Let's have a look at a simple example using the [Color Adapter](https://huggingface.co/diffusers/t2iadapter_color_sd14v1). +Let's have a look at a simple example using the [Color Adapter](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1). ```python from diffusers.utils import load_image @@ -69,7 +69,7 @@ Next, create the adapter pipeline import torch from diffusers import StableDiffusionAdapterPipeline, T2IAdapter -adapter = T2IAdapter.from_pretrained("diffusers/t2iadapter_color_sd14v1") +adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1") pipe = StableDiffusionAdapterPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", adapter=adapter, @@ -102,17 +102,17 @@ Non-diffusers checkpoints can be found under [TencentARC/T2I-Adapter](https://hu | Model Name | Control Image Overview| Control Image Example | Generated Image Example | |---|---|---|---| -|[diffusers/t2iadapter_color_sd14v1](https://huggingface.co/diffusers/t2iadapter_color_sd14v1)
*Trained with spatial color palette* | A image with 8x8 color palette.||| -|[diffusers/t2iadapter_canny_sd14v1](https://huggingface.co/diffusers/t2iadapter_canny_sd14v1)
*Trained with canny edge detection* | A monochrome image with white edges on a black background.||| -|[diffusers/t2iadapter_sketch_sd14v1](https://huggingface.co/diffusers/t2iadapter_sketch_sd14v1)
*Trained with [PidiNet](https://github.com/zhuoinoulu/pidinet) edge detection* | A hand-drawn monochrome image with white outlines on a black background.||| -|[diffusers/t2iadapter_depth_sd14v1](https://huggingface.co/diffusers/t2iadapter_depth_sd14v1)
*Trained with Midas depth estimation* | A grayscale image with black representing deep areas and white representing shallow areas.||| -|[diffusers/t2iadapter_openpose_sd14v1](https://huggingface.co/diffusers/t2iadapter_openpose_sd14v1)
*Trained with OpenPose bone image* | A [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image.||| -|[diffusers/t2iadapter_keypose_sd14v1](https://huggingface.co/diffusers/t2iadapter_keypose_sd14v1)
*Trained with mmpose skeleton image* | A [mmpose skeleton](https://github.com/open-mmlab/mmpose) image.||| -|[diffusers/t2iadapter_seg_sd14v1](https://huggingface.co/diffusers/t2iadapter_seg_sd14v1)
*Trained with semantic segmentation* | An [custom](https://github.com/TencentARC/T2I-Adapter/discussions/25) segmentation protocol image.|| | -|[diffusers/t2iadapter_canny_sd15v2](https://huggingface.co/diffusers/t2iadapter_canny_sd15v2)|| -|[diffusers/t2iadapter_depth_sd15v2](https://huggingface.co/diffusers/t2iadapter_depth_sd15v2)|| -|[diffusers/t2iadapter_sketch_sd15v2](https://huggingface.co/diffusers/t2iadapter_sketch_sd15v2)|| -|[diffusers/t2iadapter_zoedepth_sd15v1](https://huggingface.co/diffusers/t2iadapter_zoedepth_sd15v1)|| +|[TencentARC/t2iadapter_color_sd14v1](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1)
*Trained with spatial color palette* | A image with 8x8 color palette.||| +|[TencentARC/t2iadapter_canny_sd14v1](https://huggingface.co/TencentARC/t2iadapter_canny_sd14v1)
*Trained with canny edge detection* | A monochrome image with white edges on a black background.||| +|[TencentARC/t2iadapter_sketch_sd14v1](https://huggingface.co/TencentARC/t2iadapter_sketch_sd14v1)
*Trained with [PidiNet](https://github.com/zhuoinoulu/pidinet) edge detection* | A hand-drawn monochrome image with white outlines on a black background.||| +|[TencentARC/t2iadapter_depth_sd14v1](https://huggingface.co/TencentARC/t2iadapter_depth_sd14v1)
*Trained with Midas depth estimation* | A grayscale image with black representing deep areas and white representing shallow areas.||| +|[TencentARC/t2iadapter_openpose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_openpose_sd14v1)
*Trained with OpenPose bone image* | A [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image.||| +|[TencentARC/t2iadapter_keypose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_keypose_sd14v1)
*Trained with mmpose skeleton image* | A [mmpose skeleton](https://github.com/open-mmlab/mmpose) image.||| +|[TencentARC/t2iadapter_seg_sd14v1](https://huggingface.co/TencentARC/t2iadapter_seg_sd14v1)
*Trained with semantic segmentation* | An [custom](https://github.com/TencentARC/T2I-Adapter/discussions/25) segmentation protocol image.|| | +|[TencentARC/t2iadapter_canny_sd15v2](https://huggingface.co/TencentARC/t2iadapter_canny_sd15v2)|| +|[TencentARC/t2iadapter_depth_sd15v2](https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2)|| +|[TencentARC/t2iadapter_sketch_sd15v2](https://huggingface.co/TencentARC/t2iadapter_sketch_sd15v2)|| +|[TencentARC/t2iadapter_zoedepth_sd15v1](https://huggingface.co/TencentARC/t2iadapter_zoedepth_sd15v1)|| ## Combining multiple adapters @@ -151,8 +151,8 @@ from diffusers import StableDiffusionAdapterPipeline, MultiAdapter adapters = MultiAdapter( [ - T2IAdapter.from_pretrained("diffusers/t2iadapter_keypose_sd14v1"), - T2IAdapter.from_pretrained("diffusers/t2iadapter_depth_sd14v1"), + T2IAdapter.from_pretrained("TencentARC/t2iadapter_keypose_sd14v1"), + T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd14v1"), ] ) adapters = adapters.to(torch.float16) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py index f40ef36b6e2e..0c1dd1cfe87b 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py @@ -183,7 +183,7 @@ def tearDown(self): def test_stable_diffusion_adapter(self): test_cases = [ ( - "diffusers/t2iadapter_color_sd14v1", + "TencentARC/t2iadapter_color_sd14v1", "CompVis/stable-diffusion-v1-4", "snail", "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/color.png", @@ -191,7 +191,7 @@ def test_stable_diffusion_adapter(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_color_sd14v1.npy", ), ( - "diffusers/t2iadapter_depth_sd14v1", + "TencentARC/t2iadapter_depth_sd14v1", "CompVis/stable-diffusion-v1-4", "desk", "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/desk_depth.png", @@ -199,7 +199,7 @@ def test_stable_diffusion_adapter(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_depth_sd14v1.npy", ), ( - "diffusers/t2iadapter_depth_sd15v2", + "TencentARC/t2iadapter_depth_sd15v2", "runwayml/stable-diffusion-v1-5", "desk", "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/desk_depth.png", @@ -207,7 +207,7 @@ def test_stable_diffusion_adapter(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_depth_sd15v2.npy", ), ( - "diffusers/t2iadapter_keypose_sd14v1", + "TencentARC/t2iadapter_keypose_sd14v1", "CompVis/stable-diffusion-v1-4", "person", "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/person_keypose.png", @@ -215,7 +215,7 @@ def test_stable_diffusion_adapter(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_keypose_sd14v1.npy", ), ( - "diffusers/t2iadapter_openpose_sd14v1", + "TencentARC/t2iadapter_openpose_sd14v1", "CompVis/stable-diffusion-v1-4", "person", "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/iron_man_pose.png", @@ -223,7 +223,7 @@ def test_stable_diffusion_adapter(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_openpose_sd14v1.npy", ), ( - "diffusers/t2iadapter_seg_sd14v1", + "TencentARC/t2iadapter_seg_sd14v1", "CompVis/stable-diffusion-v1-4", "motorcycle", "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motor.png", @@ -231,7 +231,7 @@ def test_stable_diffusion_adapter(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_seg_sd14v1.npy", ), ( - "diffusers/t2iadapter_zoedepth_sd15v1", + "TencentARC/t2iadapter_zoedepth_sd15v1", "runwayml/stable-diffusion-v1-5", "motorcycle", "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motorcycle.png", @@ -239,7 +239,7 @@ def test_stable_diffusion_adapter(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_zoedepth_sd15v1.npy", ), ( - "diffusers/t2iadapter_canny_sd14v1", + "TencentARC/t2iadapter_canny_sd14v1", "CompVis/stable-diffusion-v1-4", "toy", "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png", @@ -247,7 +247,7 @@ def test_stable_diffusion_adapter(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_canny_sd14v1.npy", ), ( - "diffusers/t2iadapter_canny_sd15v2", + "TencentARC/t2iadapter_canny_sd15v2", "runwayml/stable-diffusion-v1-5", "toy", "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png", @@ -255,7 +255,7 @@ def test_stable_diffusion_adapter(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_canny_sd15v2.npy", ), ( - "diffusers/t2iadapter_sketch_sd14v1", + "TencentARC/t2iadapter_sketch_sd14v1", "CompVis/stable-diffusion-v1-4", "cat", "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/edge.png", @@ -263,7 +263,7 @@ def test_stable_diffusion_adapter(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_sketch_sd14v1.npy", ), ( - "diffusers/t2iadapter_sketch_sd15v2", + "TencentARC/t2iadapter_sketch_sd15v2", "runwayml/stable-diffusion-v1-5", "cat", "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/edge.png", @@ -297,7 +297,7 @@ def test_stable_diffusion_adapter_pipeline_with_sequential_cpu_offloading(self): torch.cuda.reset_max_memory_allocated() torch.cuda.reset_peak_memory_stats() - adapter = T2IAdapter.from_pretrained("diffusers/t2iadapter_seg_sd14v1") + adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_seg_sd14v1") pipe = StableDiffusionAdapterPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", adapter=adapter, safety_checker=None ) From a640ecfef3ad7dd30cb9dc7bfc31f317c3258db9 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Mon, 17 Jul 2023 19:41:00 +0000 Subject: [PATCH 10/10] clearer additional residual application --- src/diffusers/models/unet_2d_blocks.py | 12 +++++++----- .../versatile_diffusion/modeling_text_unet.py | 12 +++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 9bdcf8fa8d9c..469e501b814b 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -959,7 +959,9 @@ def forward( ): output_states = () - for resnet, attn in zip(self.resnets, self.attentions): + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -1000,11 +1002,11 @@ def custom_forward(*inputs): return_dict=False, )[0] - output_states = output_states + (hidden_states,) + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals - if additional_residuals is not None: - hidden_states += additional_residuals - output_states = output_states[:-1] + (output_states[-1] + additional_residuals,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 5a20834f9802..82628104eba2 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1407,7 +1407,9 @@ def forward( ): output_states = () - for resnet, attn in zip(self.resnets, self.attentions): + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -1448,11 +1450,11 @@ def custom_forward(*inputs): return_dict=False, )[0] - output_states = output_states + (hidden_states,) + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals - if additional_residuals is not None: - hidden_states += additional_residuals - output_states = output_states[:-1] + (output_states[-1] + additional_residuals,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: