Add support for Multi-ControlNet to StableDiffusionControlNetPipeline #2627
Changes from 22 commits
@@ -14,14 +14,18 @@

 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import PIL.Image
 import torch
+from torch import device, nn
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models.controlnet import ControlNetOutput
+from ...models.modeling_utils import get_parameter_device, get_parameter_dtype
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
@@ -85,6 +89,78 @@
 """


+class MultiControlNet(nn.Module):
+    r"""
+    Multiple `ControlNetModel` wrapper class for Multi-ControlNet.
+
+    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
+    compatible with `ControlNetModel`.
+
+    Args:
+        controlnets (`List[ControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. You must set multiple
+            `ControlNetModel` as a list.
+    """
+
+    def __init__(self, controlnets: List[ControlNetModel]):
+        super().__init__()
+        self.nets = nn.ModuleList(controlnets)
+
+    @property
+    def device(self) -> device:
+        """
+        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+        device).
+        """
+        return get_parameter_device(self)
+
+    @property
+    def dtype(self) -> torch.dtype:
+        """
+        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        return get_parameter_dtype(self)
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        controlnet_cond: List[torch.Tensor],
+        conditioning_scale: List[float],
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[ControlNetOutput, Tuple]:
+        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+            down_samples, mid_sample = controlnet(
+                sample,
+                timestep,
+                encoder_hidden_states,
+                image,
+                scale,
+                class_labels,
+                timestep_cond,
+                attention_mask,
+                cross_attention_kwargs,
+                return_dict,
+            )
+
+            # merge samples
+            if i == 0:
+                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+            else:
+                down_block_res_samples = [
+                    samples_prev + samples_curr
+                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                ]
+                mid_block_res_sample += mid_sample
+
+        return down_block_res_samples, mid_block_res_sample
+
+
 class StableDiffusionControlNetPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
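For intuition, here is a minimal standalone sketch (not part of the PR; the stub function and tensor shapes are invented) of the same merge rule `MultiControlNet.forward()` applies: the down-block residuals are summed element-wise across nets, and the mid-block residuals are accumulated in place.

    import torch

    # Stub standing in for one ControlNet call: returns pre-scaled
    # down-block residuals and one mid-block residual.
    def stub_controlnet(scale):
        down = [torch.ones(1, 4) * scale for _ in range(3)]
        mid = torch.ones(1, 4) * scale
        return down, mid

    outputs = [stub_controlnet(0.5), stub_controlnet(1.0)]

    # Same merge rule as forward() above.
    down_block_res_samples, mid_block_res_sample = outputs[0]
    for down_samples, mid_sample in outputs[1:]:
        down_block_res_samples = [
            prev + curr for prev, curr in zip(down_block_res_samples, down_samples)
        ]
        mid_block_res_sample += mid_sample

    print(mid_block_res_sample)  # all entries 1.5 = 0.5 + 1.0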
@@ -103,8 +179,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
             Tokenizer of class
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
-        controlnet ([`ControlNetModel`]):
-            Provides additional conditioning to the unet during the denoising process
+        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. You can set multiple
+            `ControlNetModel` as a list.
         scheduler ([`SchedulerMixin`]):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -122,7 +199,7 @@ def __init__(
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: ControlNetModel,
+        controlnet: Union[ControlNetModel, List[ControlNetModel]],
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
@@ -146,6 +223,11 @@ def __init__(
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )

+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNet(controlnet)
+        else:
+            controlnet = controlnet
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
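In practice this branch lets callers pass either a single model or a list. A hedged usage sketch (the checkpoint names are illustrative, not prescribed by the PR; the list is forwarded to `__init__`, which wraps it in `MultiControlNet`):

    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

    # Illustrative checkpoints; any ControlNetModel instances work.
    controlnet_pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")
    controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

    # Passing a list triggers the MultiControlNet wrapper above.
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=[controlnet_pose, controlnet_canny],
    )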
@@ -425,7 +507,6 @@ def prepare_extra_step_kwargs(self, generator, eta):
     def check_inputs(
         self,
         prompt,
-        image,
         height,
         width,
         callback_steps,
@@ -470,6 +551,7 @@ def check_inputs(
                     f" {negative_prompt_embeds.shape}."
                 )

+    def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
         image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)

[Review comment] Refactor looks good to me
@@ -501,7 +583,9 @@ def check_inputs(
                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
             )

-    def prepare_image(self, image, width, height, batch_size, num_images_per_prompt, device, dtype):
+    def prepare_image(
+        self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+    ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
                 image = [image]

[Review comment] Indeed! Since this PR is already closed, could you please open a new PR for it?
[Review comment] sure
@@ -529,6 +613,9 @@ def prepare_image(self, image, width, height, batch_size, num_images_per_prompt,

         image = image.to(device=device, dtype=dtype)

+        if do_classifier_free_guidance:
+            image = torch.cat([image] * 2)
+
         return image

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
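The duplication mirrors what already happens to the latents under classifier-free guidance: both the unconditional and conditional halves of the doubled batch need the control image. A small sketch (shapes invented for illustration):

    import torch

    # Under classifier-free guidance the latents run as a doubled batch
    # (unconditional + conditional), so the control image is doubled to match.
    image = torch.randn(2, 3, 512, 512)   # control images for a batch of 2
    image = torch.cat([image] * 2)        # -> batch of 4, aligned with latents
    print(image.shape)                    # torch.Size([4, 3, 512, 512])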
@@ -571,6 +658,53 @@ def _default_height_width(self, height, width, image):

         return height, width

+    def _prepare_images(self, image):
+        if isinstance(self.controlnet, ControlNetModel):
+            return [image]  # convert to array for internal use
+        else:  # Multi-ControlNet
+            if not isinstance(image, list):
+                raise ValueError("The `image` argument needs to be specified in a `list`.")
+
+            num_controlnets = len(self.controlnet.nets)
+            if len(image) % num_controlnets != 0:
+                raise ValueError(
+                    "The length of the `image` argument list needs to be a multiple of the number of Multi-ControlNet."
+                )
+
+            image_per_control = len(image) // num_controlnets
+
+            # let's split images over controlnets
+            return [image[i : i + image_per_control] for i in range(0, len(image), image_per_control)]
+
+    def _prepare_controlnet_conditioning_scale(self, controlnet_conditioning_scale):
+        if isinstance(self.controlnet, ControlNetModel):
+            if not isinstance(controlnet_conditioning_scale, float):
+                raise ValueError("The `controlnet_conditioning_scale` argument needs to be specified as a `float`.")
+            return controlnet_conditioning_scale
+        else:  # Multi-ControlNet
+            num_controlnets = len(self.controlnet.nets)
+            if isinstance(controlnet_conditioning_scale, list):
+                if len(controlnet_conditioning_scale) != num_controlnets:
+                    raise ValueError(
+                        "The length of the `controlnet_conditioning_scale` list does not match the number of Multi-ControlNet. "
+                        "If specified in `list`, it needs to have the same length as the number of Multi-ControlNet."
+                    )
+            else:
+                controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnets
+            return controlnet_conditioning_scale
+
+    # override DiffusionPipeline
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        safe_serialization: bool = False,
+        variant: Optional[str] = None,
+    ):
+        if isinstance(self.controlnet, ControlNetModel):
+            super().save_pretrained(save_directory, safe_serialization, variant)
+        else:
+            raise NotImplementedError("Currently, `save_pretrained()` is not implemented for Multi-ControlNet.")
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

[Review comment] That's very clean - great!
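A quick standalone illustration of the split rule in `_prepare_images` (plain list slicing; the string labels are invented): with two ControlNets and four images, each net receives a consecutive pair.

    # Four conditioning images, two ControlNets -> two images per net.
    image = ["pose_a", "pose_b", "canny_a", "canny_b"]
    num_controlnets = 2
    image_per_control = len(image) // num_controlnets  # 2

    splits = [image[i : i + image_per_control] for i in range(0, len(image), image_per_control)]
    print(splits)  # [['pose_a', 'pose_b'], ['canny_a', 'canny_b']]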
@@ -593,7 +727,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        controlnet_conditioning_scale: float = 1.0,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
|
@@ -604,8 +738,10 @@ def __call__( | |
instead. | ||
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`): | ||
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If | ||
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can | ||
also be accepted as an image. The control image is automatically resized to fit the output image. | ||
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can | ||
also be accepted as an image. The control image is automatically resized to fit the output image. If | ||
multiple ControlNets are specified in init, you need to set the corresponding images in the form of a | ||
list of `List[torch.FloatTensor]` or `List[PIL.Image.Image]`. | ||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): | ||
The height in pixels of the generated image. | ||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): | ||
|
@@ -658,10 +794,10 @@ def __call__(
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-            controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
-                to the residual in the original unet.
-
+                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+                corresponding scale as a list.
         Examples:

         Returns:
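Putting the two additions together, a hedged call sketch (continuing the hypothetical `pipe` from the construction example above; `pose_image` and `canny_image` are assumed PIL images, one per ControlNet):

    # One conditioning image per ControlNet, each with its own strength.
    result = pipe(
        prompt="a dancer in a futuristic city",
        image=[pose_image, canny_image],
        controlnet_conditioning_scale=[1.0, 0.5],  # halve the canny influence
    )
    result.images[0].save("multi_controlnet.png")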
@@ -671,13 +807,21 @@ def __call__(
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
+        # prepare `images` and `controlnet_conditioning_scale` for both a ControlNet and Multi-ControlNet
+        # `images` here is a list where each element is a conditioning image for each ControlNet.
+        images = self._prepare_images(image)
+        controlnet_conditioning_scale = self._prepare_controlnet_conditioning_scale(controlnet_conditioning_scale)
+
         # 0. Default height and width to unet
-        height, width = self._default_height_width(height, width, image)
+        height, width = self._default_height_width(height, width, images[0])

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt, image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
         )
+        for image in images:
+            self.check_image(image, prompt, prompt_embeds)

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -705,18 +849,19 @@ def __call__(
         )

         # 4. Prepare image
-        image = self.prepare_image(
-            image,
-            width,
-            height,
-            batch_size * num_images_per_prompt,
-            num_images_per_prompt,
-            device,
-            self.controlnet.dtype,
-        )
-
-        if do_classifier_free_guidance:
-            image = torch.cat([image] * 2)
+        images = [
+            self.prepare_image(
+                image=image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=self.controlnet.dtype,
+                do_classifier_free_guidance=do_classifier_free_guidance,
+            )
+            for image in images
+        ]

         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -746,20 +891,16 @@ def __call__(
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

+                # controlnet(s) inference
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
-                    controlnet_cond=image,
+                    controlnet_cond=images[0] if len(images) == 1 else images,
+                    conditioning_scale=controlnet_conditioning_scale,
                     return_dict=False,
                 )

-                down_block_res_samples = [
-                    down_block_res_sample * controlnet_conditioning_scale
-                    for down_block_res_sample in down_block_res_samples
-                ]
-                mid_block_res_sample *= controlnet_conditioning_scale
-
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,

[Review comment] I like the clean-up here, but I think since we have to pay attention to this: https://github.com/huggingface/diffusers/pull/2627/files#r1132300008 we should maybe put all this directly in the
[Review comment] That's a nice idea! I'll write it in that direction.
[Review comment] Fixed in d1acef4
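The review thread above is why the post-hoc multiply disappeared: the scale is now handed to the ControlNet call via `conditioning_scale` and applied per net before the residuals are merged. A toy check (not from the PR; tensors and scales invented) of why the order matters with multiple nets:

    import torch

    # Residuals from two nets and their per-net scales.
    r1, r2 = torch.ones(1, 4), torch.ones(1, 4) * 2.0
    s1, s2 = 1.0, 0.5

    # Scaling inside each net before the merge supports per-net strengths...
    merged = r1 * s1 + r2 * s2   # tensor of 2.0s

    # ...whereas a single post-hoc multiply on the merged sum cannot
    # reproduce this unless all scales are equal.
    post_hoc = (r1 + r2) * s1    # tensor of 3.0s, not equal
    print(merged, post_hoc)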
[Review comment] great!