Add support for Multi-ControlNet to StableDiffusionControlNetPipeline #2627
Changes from all commits
@@ -14,14 +14,18 @@
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import PIL.Image
 import torch
+from torch import nn
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models.controlnet import ControlNetOutput
+from ...models.modeling_utils import ModelMixin
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
@@ -85,6 +89,63 @@
 """


+class MultiControlNetModel(ModelMixin):
+    r"""
+    Multiple `ControlNetModel` wrapper class for Multi-ControlNet.
+
+    This module is a wrapper for multiple instances of `ControlNetModel`. The `forward()` API is designed to be
+    compatible with `ControlNetModel`.
+
+    Args:
+        controlnets (`List[ControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. You must set multiple
+            `ControlNetModel` as a list.
+    """
+
+    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
+        super().__init__()
+        self.nets = nn.ModuleList(controlnets)
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        controlnet_cond: List[torch.Tensor],
+        conditioning_scale: List[float],
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[ControlNetOutput, Tuple]:
+        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+            down_samples, mid_sample = controlnet(
+                sample,
+                timestep,
+                encoder_hidden_states,
+                image,
+                scale,
+                class_labels,
+                timestep_cond,
+                attention_mask,
+                cross_attention_kwargs,
+                return_dict,
+            )
+
+            # merge samples
+            if i == 0:
+                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+            else:
+                down_block_res_samples = [
+                    samples_prev + samples_curr
+                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                ]
+                mid_block_res_sample += mid_sample
+
+        return down_block_res_samples, mid_block_res_sample
+
+
 class StableDiffusionControlNetPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
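To illustrate the merging behavior in `forward()` above, a minimal sketch with toy tensors standing in for two ControlNets' residuals (shapes are illustrative, not actual model outputs):

import torch

# Toy stand-ins for the down-block and mid-block residuals of two ControlNets.
down_a, mid_a = [torch.ones(1, 320, 64, 64)], torch.ones(1, 1280, 8, 8)
down_b, mid_b = [torch.full((1, 320, 64, 64), 2.0)], torch.full((1, 1280, 8, 8), 2.0)

# The wrapper sums residuals elementwise across nets, mirroring the i == 0 / else branches above.
down_merged = [prev + curr for prev, curr in zip(down_a, down_b)]
mid_merged = mid_a + mid_b
print(down_merged[0].mean().item(), mid_merged.mean().item())  # 3.0 3.0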
@@ -103,8 +164,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
             Tokenizer of class
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
-        controlnet ([`ControlNetModel`]):
-            Provides additional conditioning to the unet during the denoising process
+        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
+            as a list, the outputs from each ControlNet are added together to create one combined additional
+            conditioning.
         scheduler ([`SchedulerMixin`]):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -122,7 +185,7 @@ def __init__(
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: ControlNetModel,
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
@@ -146,6 +209,9 @@ def __init__(
             " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
         )

+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetModel(controlnet)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
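A minimal usage sketch of the list-wrapping behavior above. The Hub model IDs below are illustrative examples and assumed to be available:

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Two conditionings, e.g. canny edges plus human pose; model IDs are assumptions.
controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
controlnet_pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=[controlnet_canny, controlnet_pose],  # the list is wrapped into MultiControlNetModel in __init__
    torch_dtype=torch.float16,
)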
@@ -432,6 +498,7 @@ def check_inputs(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        controlnet_conditioning_scale=1.0,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -470,6 +537,41 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )

+        # Check `image`
+        if isinstance(self.controlnet, ControlNetModel):
+            self.check_image(image, prompt, prompt_embeds)
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+            if len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    "For multiple controlnets: `image` must have the same length as the number of controlnets."
+                )
+
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
+        else:
+            assert False
+
+        # Check `controlnet_conditioning_scale`
+        if isinstance(self.controlnet, ControlNetModel):
+            if not isinstance(controlnet_conditioning_scale, float):
+                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+                self.controlnet.nets
+            ):
+                raise ValueError(
+                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+                    " the same length as the number of controlnets"
+                )
+        else:
+            assert False
+
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
         image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)

Review comment on `check_image`: Refactor looks good to me
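As a sketch of the new validation, continuing the hypothetical two-ControlNet `pipe` from the earlier example (`canny_image` is a hypothetical single PIL image):

# Passing a single image instead of a list with two ControlNets is rejected in step 1 of __call__:
try:
    pipe(prompt="a living room", image=canny_image)
except TypeError as err:
    print(err)  # For multiple controlnets: `image` must be type `list`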
@@ -501,7 +603,9 @@ def check_inputs(
                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
             )

-    def prepare_image(self, image, width, height, batch_size, num_images_per_prompt, device, dtype):
+    def prepare_image(
+        self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+    ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
                 image = [image]

Review thread (the opening comment was hidden):
- Indeed! Since this PR is already closed, could you please open a new PR for it?
- sure
@@ -529,6 +633,9 @@ def prepare_image(self, image, width, height, batch_size, num_images_per_prompt,

         image = image.to(device=device, dtype=dtype)

+        if do_classifier_free_guidance:
+            image = torch.cat([image] * 2)
+
         return image

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
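The duplication above mirrors classifier-free guidance: the latent batch is doubled (unconditional plus conditional), so the control image batch must match. A toy shape check, with illustrative shapes:

import torch

image = torch.randn(1, 3, 512, 512)             # prepared control image
image = torch.cat([image] * 2)                  # same duplication as prepare_image above
latent_model_input = torch.randn(2, 4, 64, 64)  # [uncond, cond] latents under CFG
assert image.shape[0] == latent_model_input.shape[0]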
@@ -550,7 +657,10 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         return latents

     def _default_height_width(self, height, width, image):
-        if isinstance(image, list):
+        # NOTE: It is possible that a list of images have different
+        # dimensions for each image, so just checking the first image
+        # is not _exactly_ correct, but it is simple.
+        while isinstance(image, list):
             image = image[0]

         if height is None:
@@ -571,6 +681,18 @@ def _default_height_width(self, height, width, image):

         return height, width

+    # override DiffusionPipeline
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        safe_serialization: bool = False,
+        variant: Optional[str] = None,
+    ):
+        if isinstance(self.controlnet, ControlNetModel):
+            super().save_pretrained(save_directory, safe_serialization, variant)
+        else:
+            raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

Review comment on `save_pretrained`: That's very clean - great!
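A short sketch of the override's behavior (the directory name is hypothetical):

# Single-ControlNet pipelines save as before; the multi case raises for now.
try:
    pipe.save_pretrained("./multi-controlnet-pipe")
except NotImplementedError as err:
    print(err)  # Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.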
@@ -593,7 +715,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        controlnet_conditioning_scale: float = 1.0,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -602,10 +724,14 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
+                `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
-                the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
-                also be accepted as an image. The control image is automatically resized to fit the output image.
+                the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+                also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
+                height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
+                specified in init, images must be passed as a list such that each element of the list can be correctly
+                batched for input to a single controlnet.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -658,10 +784,10 @@ def __call__(
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-            controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
-                to the residual in the original unet.
-
+                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+                corresponding scale as a list.
         Examples:

         Returns:
@@ -676,7 +802,15 @@ def __call__(

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt, image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+            prompt,
+            image,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            controlnet_conditioning_scale,
         )

         # 2. Define call parameters
@@ -693,6 +827,9 @@ def __call__(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0

+        if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
+
         # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
             prompt,
@@ -705,18 +842,37 @@
         )

         # 4. Prepare image
-        image = self.prepare_image(
-            image,
-            width,
-            height,
-            batch_size * num_images_per_prompt,
-            num_images_per_prompt,
-            device,
-            self.controlnet.dtype,
-        )
-
-        if do_classifier_free_guidance:
-            image = torch.cat([image] * 2)
+        if isinstance(self.controlnet, ControlNetModel):
+            image = self.prepare_image(
+                image=image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=self.controlnet.dtype,
+                do_classifier_free_guidance=do_classifier_free_guidance,
+            )
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            images = []
+
+            for image_ in image:
+                image_ = self.prepare_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=self.controlnet.dtype,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                )
+
+                images.append(image_)
+
+            image = images
+        else:
+            assert False

         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -746,20 +902,16 @@ def __call__(
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

+                # controlnet(s) inference
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
                     controlnet_cond=image,
+                    conditioning_scale=controlnet_conditioning_scale,
                     return_dict=False,
                 )

-                down_block_res_samples = [
-                    down_block_res_sample * controlnet_conditioning_scale
-                    for down_block_res_sample in down_block_res_samples
-                ]
-                mid_block_res_sample *= controlnet_conditioning_scale
-
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,

Review thread:
- I like the clean-up here, but I think since we have to pay attention to this: https://github.com/huggingface/diffusers/pull/2627/files#r1132300008 we should maybe put all this directly in the …
- That's a nice idea! I'll write it in that direction.
- Fixed in d1acef4

Review comment: great!