
Add support for Multi-ControlNet to StableDiffusionControlNetPipeline #2627


Merged: 25 commits, merged on Mar 13, 2023
Commits (25 total; the diff below shows changes from 22 commits)
52b7c69
support for List[ControlNetModel] on init()
takuma104 Mar 8, 2023
95e4a12
Add to support for multiple ControlNetCondition
takuma104 Mar 8, 2023
3c2fd98
rename conditioning_scale to scale
takuma104 Mar 8, 2023
aae2fc1
scaling bugfix
takuma104 Mar 8, 2023
b14b304
Manually merge `MultiControlNet` #2621
takuma104 Mar 9, 2023
6b38c3a
Merge branch 'huggingface:main' into multi-controlnet-ext
takuma104 Mar 9, 2023
dd96189
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
takuma104 Mar 10, 2023
6eb98e0
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
takuma104 Mar 10, 2023
9ca4aaa
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
takuma104 Mar 10, 2023
32c8873
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
takuma104 Mar 10, 2023
987567d
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
takuma104 Mar 10, 2023
91affab
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
takuma104 Mar 10, 2023
d1acef4
cleanups
takuma104 Mar 10, 2023
9f04578
make style error correct
takuma104 Mar 10, 2023
fb39868
remove ControlNetCondition to reduce code diff
takuma104 Mar 10, 2023
9ff7978
refactoring image/cond_scale
takuma104 Mar 10, 2023
dca3b5e
add explain for `images`
takuma104 Mar 10, 2023
22c8661
Merge branch 'huggingface:main' into multi-controlnet-ext
takuma104 Mar 11, 2023
52171e3
Add docstrings
takuma104 Mar 12, 2023
fa2010e
all fast-test passed
takuma104 Mar 12, 2023
4297328
Add a slow test
takuma104 Mar 12, 2023
2528494
nit
takuma104 Mar 12, 2023
a2140aa
Apply suggestions from code review
patrickvonplaten Mar 13, 2023
039db1b
small precision fix
patrickvonplaten Mar 13, 2023
a257ed5
nits
williamberman Mar 13, 2023
5 changes: 5 additions & 0 deletions src/diffusers/models/controlnet.py
@@ -389,6 +389,7 @@ def forward(
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.FloatTensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
@@ -492,6 +493,10 @@ def forward(

mid_block_res_sample = self.controlnet_mid_block(sample)

# 6. scaling
Review comment (Contributor): great!

down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample *= conditioning_scale

if not return_dict:
return (down_block_res_samples, mid_block_res_sample)
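For reference, a minimal sketch (variable names and shapes are illustrative, not from this PR) of how a caller can use the new `conditioning_scale` argument to attenuate a single ControlNet's influence:

    # Halve this ControlNet's influence: every down/mid residual is scaled by 0.5.
    down_res, mid_res = controlnet(
        latent_model_input,                  # e.g. (2 * batch, 4, 64, 64) noisy latents
        t,                                   # current diffusion timestep
        encoder_hidden_states=prompt_embeds,
        controlnet_cond=cond_image,          # e.g. (2 * batch, 3, 512, 512) control image
        conditioning_scale=0.5,
        return_dict=False,
    )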

@@ -14,14 +14,18 @@


import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
from torch import device, nn
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.controlnet import ControlNetOutput
from ...models.modeling_utils import get_parameter_device, get_parameter_dtype
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
@@ -85,6 +89,78 @@
"""


class MultiControlNet(nn.Module):
r"""
Multiple `ControlNetModel` wrapper class for Multi-ControlNet

This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
compatible with `ControlNetModel`.

Args:
controlnets (`List[ControlNetModel]`):
Provides additional conditioning to the unet during the denoising process. You must set multiple
`ControlNetModel` as a list.
"""

def __init__(self, controlnets: List[ControlNetModel]):
super().__init__()
self.nets = nn.ModuleList(controlnets)

@property
def device(self) -> device:
"""
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
device).
"""
return get_parameter_device(self)

@property
def dtype(self) -> torch.dtype:
"""
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
"""
return get_parameter_dtype(self)

def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: List[torch.Tensor],
conditioning_scale: List[float],
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
down_samples, mid_sample = controlnet(
sample,
timestep,
encoder_hidden_states,
image,
scale,
class_labels,
timestep_cond,
attention_mask,
cross_attention_kwargs,
return_dict,
)

# merge samples
if i == 0:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
down_block_res_samples = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
]
mid_block_res_sample += mid_sample

return down_block_res_samples, mid_block_res_sample
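The merge above is an elementwise sum of the per-ControlNet residuals, each already scaled by its own entry of `conditioning_scale` inside `ControlNetModel.forward`. A tiny sketch of the resulting semantics (illustrative tensors only, not part of this PR):

    import torch

    # Hypothetical mid-block residuals produced by two ControlNets.
    mid_1 = torch.ones(1, 1280, 8, 8)
    mid_2 = torch.ones(1, 1280, 8, 8)

    # With conditioning_scale=[1.0, 0.5], the merged mid residual is mid_1 * 1.0 + mid_2 * 0.5
    # (every element equals 1.5 here); the down-block residuals are merged the same way.
    merged = mid_1 * 1.0 + mid_2 * 0.5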


class StableDiffusionControlNetPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -103,8 +179,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
controlnet ([`ControlNetModel`]):
Provides additional conditioning to the unet during the denoising process
controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
Provides additional conditioning to the unet during the denoising process. You can set multiple
`ControlNetModel` as a list.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -122,7 +199,7 @@ def __init__(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
controlnet: ControlNetModel,
controlnet: Union[ControlNetModel, List[ControlNetModel]],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
@@ -146,6 +223,11 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)

if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNet(controlnet)
else:
controlnet = controlnet

self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -425,7 +507,6 @@ def prepare_extra_step_kwargs(self, generator, eta):
def check_inputs(
self,
prompt,
image,
height,
width,
callback_steps,
@@ -470,6 +551,7 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

def check_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
Review comment (Contributor): Refactor looks good to me

image_is_tensor = isinstance(image, torch.Tensor)
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
@@ -501,7 +583,9 @@ def check_inputs(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
)

def prepare_image(self, image, width, height, batch_size, num_images_per_prompt, device, dtype):
def prepare_image(
self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
):
Review comment (Contributor): do_classifier_free_guidance should have a default value to not break existing code that depends on StableDiffusionControlNetPipeline (like StableDiffusionControlNetInpaintPipeline)
Reply (Contributor Author): Indeed! Since this PR is already closed, could you please open a new PR for it?
Reply (Contributor): sure

if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
@@ -529,6 +613,9 @@ def prepare_image(self, image, width, height, batch_size, num_images_per_prompt,

image = image.to(device=device, dtype=dtype)

if do_classifier_free_guidance:
image = torch.cat([image] * 2)

return image
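Following up on the review thread above, a sketch of the suggested backward-compatible signature (not part of this PR; the default value shown is an assumption chosen to preserve the previous behavior, where callers duplicated the image for classifier-free guidance themselves):

    def prepare_image(
        self, image, width, height, batch_size, num_images_per_prompt, device, dtype,
        do_classifier_free_guidance=False,
    ):
        ...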

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
@@ -571,6 +658,53 @@ def _default_height_width(self, height, width, image):

return height, width

def _prepare_images(self, image):
if isinstance(self.controlnet, ControlNetModel):
return [image] # convert to array for internal use
else: # Multi-Controlnet
if not isinstance(image, list):
raise ValueError("The `image` argument needs to be specified in a `list`.")

num_controlnets = len(self.controlnet.nets)
if len(image) % num_controlnets != 0:
raise ValueError(
"The length of the `image` argument list needs to be a multiple of the number of Multi-ControlNet."
)

image_per_control = len(image) // num_controlnets

# let's split images over controlnets
return [image[i : i + image_per_control] for i in range(0, len(image), image_per_control)]
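As an illustration of the splitting rule (hypothetical images), two ControlNets and a flat list of four conditioning images yield one contiguous chunk per ControlNet:

    # len(self.controlnet.nets) == 2, so image_per_control = 4 // 2 = 2
    images = pipe._prepare_images([img_a1, img_a2, img_b1, img_b2])
    # -> [[img_a1, img_a2], [img_b1, img_b2]]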

def _prepare_controlnet_conditioning_scale(self, controlnet_conditioning_scale):
if isinstance(self.controlnet, ControlNetModel):
if not isinstance(controlnet_conditioning_scale, float):
raise ValueError("The `controlnet_conditioning_scale` argument needs to be specified as a `float`.")
return controlnet_conditioning_scale
else: # Multi-Controlnet
num_controlnets = len(self.controlnet.nets)
if isinstance(controlnet_conditioning_scale, list):
if len(controlnet_conditioning_scale) != num_controlnets:
raise ValueError(
"The length of the `controlnet_conditioning_scale` list does not match the number of Multi-ControlNet. "
"If specified in `list`, it needs to have the same length as the number of Multi-ControlNet."
)
else:
controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnets
return controlnet_conditioning_scale
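Quick illustration (hypothetical two-ControlNet pipeline): a bare float is broadcast to every net, while a list must match the number of nets exactly:

    # len(self.controlnet.nets) == 2
    pipe._prepare_controlnet_conditioning_scale(1.0)         # -> [1.0, 1.0]
    pipe._prepare_controlnet_conditioning_scale([1.0, 0.5])  # -> [1.0, 0.5]
    pipe._prepare_controlnet_conditioning_scale([1.0])       # raises ValueError (length mismatch)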

# override DiffusionPipeline
def save_pretrained(
Review comment (Contributor): That's very clean - great!

self,
save_directory: Union[str, os.PathLike],
safe_serialization: bool = False,
variant: Optional[str] = None,
):
if isinstance(self.controlnet, ControlNetModel):
super().save_pretrained(save_directory, safe_serialization, variant)
else:
raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -593,7 +727,7 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: float = 1.0,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -604,8 +738,10 @@ def __call__(
instead.
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
also be accepted as an image. The control image is automatically resized to fit the output image.
the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
also be accepted as an image. The control image is automatically resized to fit the output image. If
multiple ControlNets are specified in init, the corresponding images must be passed as a list
(`List[torch.FloatTensor]` or `List[PIL.Image.Image]`).
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -658,10 +794,10 @@ def __call__(
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original unet.

to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
corresponding scale as a list.
Examples:

Returns:
@@ -671,13 +807,21 @@ def __call__(
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""

# prepare `images` and `controlnet_conditioning_scale` for both a ControlNet and Multi-Controlnet
# `images` here is a list where each element is a conditioning image for each ControlNet.
images = self._prepare_images(image)
controlnet_conditioning_scale = self._prepare_controlnet_conditioning_scale(controlnet_conditioning_scale)

# 0. Default height and width to unet
height, width = self._default_height_width(height, width, image)
height, width = self._default_height_width(height, width, images[0])

# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt, image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)
for image in images:
self.check_image(image, prompt, prompt_embeds)

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -705,18 +849,19 @@ def __call__(
)

# 4. Prepare image
image = self.prepare_image(
image,
width,
height,
batch_size * num_images_per_prompt,
num_images_per_prompt,
device,
self.controlnet.dtype,
)

if do_classifier_free_guidance:
image = torch.cat([image] * 2)
images = [
self.prepare_image(
image=image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
for image in images
]

# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -746,20 +891,16 @@ def __call__(
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# controlnet(s) inference
down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=image,
controlnet_cond=images[0] if len(images) == 1 else images,
conditioning_scale=controlnet_conditioning_scale,
return_dict=False,
)

down_block_res_samples = [
Review comment (patrickvonplaten, Contributor, Mar 10, 2023): I like the clean-up here, but I think since we have to pay attention to this: https://github.com/huggingface/diffusers/pull/2627/files#r1132300008 we should maybe put all this directly in the ControlNetModel forward call? Also see: https://github.com/huggingface/diffusers/pull/2627/files#r1132302783
Reply (Contributor Author): That's a nice idea! I'll write it in that direction.
Reply (Contributor Author): Fixed in d1acef4

down_block_res_sample * controlnet_conditioning_scale
for down_block_res_sample in down_block_res_samples
]
mid_block_res_sample *= controlnet_conditioning_scale

# predict the noise residual
noise_pred = self.unet(
latent_model_input,