Skip to content

Add support for Multi-ControlNet to StableDiffusionControlNetPipeline #2627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
52b7c69
support for List[ControlNetModel] on init()
takuma104 Mar 8, 2023
95e4a12
Add to support for multiple ControlNetCondition
takuma104 Mar 8, 2023
3c2fd98
rename conditioning_scale to scale
takuma104 Mar 8, 2023
aae2fc1
scaling bugfix
takuma104 Mar 8, 2023
b14b304
Manually merge `MultiControlNet` #2621
takuma104 Mar 9, 2023
6b38c3a
Merge branch 'huggingface:main' into multi-controlnet-ext
takuma104 Mar 9, 2023
dd96189
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
takuma104 Mar 10, 2023
6eb98e0
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
takuma104 Mar 10, 2023
9ca4aaa
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
takuma104 Mar 10, 2023
32c8873
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
takuma104 Mar 10, 2023
987567d
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
takuma104 Mar 10, 2023
91affab
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
takuma104 Mar 10, 2023
d1acef4
cleanups
takuma104 Mar 10, 2023
9f04578
make style error correct
takuma104 Mar 10, 2023
fb39868
remove ControlNetCondition to reduce code diff
takuma104 Mar 10, 2023
9ff7978
refactoring image/cond_scale
takuma104 Mar 10, 2023
dca3b5e
add explain for `images`
takuma104 Mar 10, 2023
22c8661
Merge branch 'huggingface:main' into multi-controlnet-ext
takuma104 Mar 11, 2023
52171e3
Add docstrings
takuma104 Mar 12, 2023
fa2010e
all fast-test passed
takuma104 Mar 12, 2023
4297328
Add a slow test
takuma104 Mar 12, 2023
2528494
nit
takuma104 Mar 12, 2023
a2140aa
Apply suggestions from code review
patrickvonplaten Mar 13, 2023
039db1b
small precision fix
patrickvonplaten Mar 13, 2023
a257ed5
nits
williamberman Mar 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/diffusers/models/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def forward(
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.FloatTensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -492,6 +493,10 @@ def forward(

mid_block_res_sample = self.controlnet_mid_block(sample)

# 6. scaling
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great!

down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample *= conditioning_scale

if not return_dict:
return (down_block_res_samples, mid_block_res_sample)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@


import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
from torch import nn
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.controlnet import ControlNetOutput
from ...models.modeling_utils import ModelMixin
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
Expand Down Expand Up @@ -85,6 +89,63 @@
"""


class MultiControlNetModel(ModelMixin):
    r"""
    Multiple `ControlNetModel` wrapper class for Multi-ControlNet

    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
    compatible with `ControlNetModel`: every wrapped ControlNet is run on the same `sample`, and the per-block
    residuals they produce are summed element-wise into a single set of residuals.

    Args:
        controlnets (`List[ControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `ControlNetModel` as a list.
    """

    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel, ...]]):
        super().__init__()
        # `nn.ModuleList` registers every ControlNet as a sub-module of this wrapper.
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        r"""
        Run every wrapped ControlNet and sum their residuals.

        Args:
            sample: Forwarded unchanged to every ControlNet.
            timestep: Forwarded unchanged to every ControlNet.
            encoder_hidden_states: Forwarded unchanged to every ControlNet.
            controlnet_cond: One conditioning input per ControlNet, in the same order as `self.nets`.
            conditioning_scale: One scale per ControlNet, in the same order as `self.nets`.
            class_labels, timestep_cond, attention_mask, cross_attention_kwargs:
                Forwarded unchanged to every ControlNet.
            return_dict: If `False`, return a plain `(down_block_res_samples, mid_block_res_sample)` tuple;
                otherwise return a `ControlNetOutput`.

        Returns:
            The summed down-block residuals and the summed mid-block residual, as a tuple or a `ControlNetOutput`
            depending on `return_dict`.

        Raises:
            ValueError: If no ControlNet is wrapped, or if `controlnet_cond` / `conditioning_scale` do not have
                exactly one entry per ControlNet.
        """
        num_nets = len(self.nets)
        if num_nets == 0:
            raise ValueError("`MultiControlNetModel` requires at least one `ControlNetModel`.")
        # `zip()` would silently ignore surplus ControlNets or inputs, so fail loudly on a length mismatch.
        if len(controlnet_cond) != num_nets or len(conditioning_scale) != num_nets:
            raise ValueError(
                "`controlnet_cond` and `conditioning_scale` must each have one entry per ControlNet:"
                f" got {len(controlnet_cond)} and {len(conditioning_scale)} for {num_nets} ControlNets."
            )

        down_block_res_samples = None
        mid_block_res_sample = None

        for image, scale, controlnet in zip(controlnet_cond, conditioning_scale, self.nets):
            # Always request the tuple form from the sub-model: a `ControlNetOutput` is a mapping, so unpacking
            # it would yield its field names rather than the tensors.
            down_samples, mid_sample = controlnet(
                sample,
                timestep,
                encoder_hidden_states,
                image,
                scale,
                class_labels,
                timestep_cond,
                attention_mask,
                cross_attention_kwargs,
                return_dict=False,
            )

            # merge samples
            if down_block_res_samples is None:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                # Out-of-place add so the tensor returned by the first ControlNet is never mutated.
                mid_block_res_sample = mid_block_res_sample + mid_sample

        if not return_dict:
            return down_block_res_samples, mid_block_res_sample

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )


class StableDiffusionControlNetPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
Expand All @@ -103,8 +164,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
controlnet ([`ControlNetModel`]):
Provides additional conditioning to the unet during the denoising process
controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
as a list, the outputs from each ControlNet are added together to create one combined additional
conditioning.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
Expand All @@ -122,7 +185,7 @@ def __init__(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
controlnet: ControlNetModel,
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
Expand All @@ -146,6 +209,9 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)

if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNetModel(controlnet)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
Expand Down Expand Up @@ -432,6 +498,7 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
controlnet_conditioning_scale=1.0,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
Expand Down Expand Up @@ -470,6 +537,41 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

# Check `image`

if isinstance(self.controlnet, ControlNetModel):
self.check_image(image, prompt, prompt_embeds)
elif isinstance(self.controlnet, MultiControlNetModel):
if not isinstance(image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")

if len(image) != len(self.controlnet.nets):
raise ValueError(
"For multiple controlnets: `image` must have the same length as the number of controlnets."
)

for image_ in image:
self.check_image(image_, prompt, prompt_embeds)
else:
assert False

# Check `controlnet_conditioning_scale`

if isinstance(self.controlnet, ControlNetModel):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(self.controlnet, MultiControlNetModel):
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets
):
raise ValueError(
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
" the same length as the number of controlnets"
)
else:
assert False

def check_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor looks good to me

image_is_tensor = isinstance(image, torch.Tensor)
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
Expand Down Expand Up @@ -501,7 +603,9 @@ def check_inputs(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
)

def prepare_image(self, image, width, height, batch_size, num_images_per_prompt, device, dtype):
def prepare_image(
self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do_classifier_free_guidance should have a default value to not break existing code that depends on StableDiffusionControlNetPipeline (like StableDiffusionControlNetInpaintPipeline)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed! Since this PR is already closed, could you please open a new PR for it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
Expand Down Expand Up @@ -529,6 +633,9 @@ def prepare_image(self, image, width, height, batch_size, num_images_per_prompt,

image = image.to(device=device, dtype=dtype)

if do_classifier_free_guidance:
image = torch.cat([image] * 2)

return image

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
Expand All @@ -550,7 +657,10 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
return latents

def _default_height_width(self, height, width, image):
if isinstance(image, list):
# NOTE: It is possible that a list of images have different
# dimensions for each image, so just checking the first image
# is not _exactly_ correct, but it is simple.
while isinstance(image, list):
image = image[0]

if height is None:
Expand All @@ -571,6 +681,18 @@ def _default_height_width(self, height, width, image):

return height, width

# override DiffusionPipeline
def save_pretrained(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's very clean - great!

self,
save_directory: Union[str, os.PathLike],
safe_serialization: bool = False,
variant: Optional[str] = None,
):
if isinstance(self.controlnet, ControlNetModel):
super().save_pretrained(save_directory, safe_serialization, variant)
else:
raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand All @@ -593,7 +715,7 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: float = 1.0,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
):
r"""
Function invoked when calling the pipeline for generation.
Expand All @@ -602,10 +724,14 @@ def __call__(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
`List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
also be accepted as an image. The control image is automatically resized to fit the output image.
the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
specified in init, images must be passed as a list such that each element of the list can be correctly
batched for input to a single controlnet.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
Expand Down Expand Up @@ -658,10 +784,10 @@ def __call__(
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original unet.

to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
corresponding scale as a list.
Examples:

Returns:
Expand All @@ -676,7 +802,15 @@ def __call__(

# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt, image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
prompt,
image,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
controlnet_conditioning_scale,
)

# 2. Define call parameters
Expand All @@ -693,6 +827,9 @@ def __call__(
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0

if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)

# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
prompt,
Expand All @@ -705,18 +842,37 @@ def __call__(
)

# 4. Prepare image
image = self.prepare_image(
image,
width,
height,
batch_size * num_images_per_prompt,
num_images_per_prompt,
device,
self.controlnet.dtype,
)
if isinstance(self.controlnet, ControlNetModel):
image = self.prepare_image(
image=image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
elif isinstance(self.controlnet, MultiControlNetModel):
images = []

for image_ in image:
image_ = self.prepare_image(
image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)

if do_classifier_free_guidance:
image = torch.cat([image] * 2)
images.append(image_)

image = images
else:
assert False

# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
Expand Down Expand Up @@ -746,20 +902,16 @@ def __call__(
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# controlnet(s) inference
down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=image,
conditioning_scale=controlnet_conditioning_scale,
return_dict=False,
)

down_block_res_samples = [
Copy link
Contributor

@patrickvonplaten patrickvonplaten Mar 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the clean-up here, but I think since we have to pay attention to this: https://github.com/huggingface/diffusers/pull/2627/files#r1132300008 we should maybe put all this directly in the ControlNetModel forward call? Also see: https://github.com/huggingface/diffusers/pull/2627/files#r1132302783

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a nice idea! I'll write it in that direction.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in d1acef4

down_block_res_sample * controlnet_conditioning_scale
for down_block_res_sample in down_block_res_samples
]
mid_block_res_sample *= controlnet_conditioning_scale

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
Expand Down
Loading