
Commit d9b8adc

takuma104, patrickvonplaten, and williamberman authored
Add support for Multi-ControlNet to StableDiffusionControlNetPipeline (#2627)
* support for List[ControlNetModel] on init()
* Add to support for multiple ControlNetCondition
* rename conditioning_scale to scale
* scaling bugfix
* Manually merge `MultiControlNet` #2621
  Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py (review suggestions, applied six times)
  Co-authored-by: Patrick von Platen <[email protected]>
* cleanups
  - don't expose ControlNetCondition
  - move scaling to ControlNetModel
* make style error correct
* remove ControlNetCondition to reduce code diff
* refactoring image/cond_scale
* add explanation for `images`
* Add docstrings
* all fast tests passed
* Add a slow test
* nit
* Apply suggestions from code review
* small precision fix
* nits
  - MultiControlNet -> MultiControlNetModel: matches existing naming a bit closer
  - MultiControlNetModel inherits from the model utils class: don't have to re-write the fp16 test
  - Skip tests that save a multi-controlnet pipeline: clearer than changing the test body
  - Don't auto-batch the number of input images to the number of controlnets; we generally like to require the user to pass the expected number of inputs, and this simplifies the processing code a bit more
  - Use the existing image pre-processing code a bit more; we can rely on it and keep the inference loop a bit simpler

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: William Berman <[email protected]>
1 parent 4ae54b3 commit d9b8adc
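
For orientation, here is a minimal usage sketch of the API this commit adds. The checkpoint names, image URLs, and prompt are illustrative placeholders, not taken from the diff:

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Illustrative checkpoint names; any two ControlNet checkpoints work the same way.
controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
controlnet_pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)

# Passing a list of ControlNetModels wraps them in MultiControlNetModel internally.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=[controlnet_canny, controlnet_pose],
    torch_dtype=torch.float16,
).to("cuda")

# One condition image per controlnet, and (optionally) one scale per controlnet.
canny_image = load_image("https://example.com/canny.png")      # placeholder URL
pose_image = load_image("https://example.com/openpose.png")    # placeholder URL

image = pipe(
    "a futuristic city street, best quality",
    image=[canny_image, pose_image],
    controlnet_conditioning_scale=[1.0, 0.8],
    num_inference_steps=20,
).images[0]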

File tree

3 files changed: +384 -31 lines changed


src/diffusers/models/controlnet.py

+5
@@ -389,6 +389,7 @@ def forward(
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         controlnet_cond: torch.FloatTensor,
+        conditioning_scale: float = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -492,6 +493,10 @@ def forward(
 
         mid_block_res_sample = self.controlnet_mid_block(sample)
 
+        # 6. scaling
+        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+        mid_block_res_sample *= conditioning_scale
+
         if not return_dict:
             return (down_block_res_samples, mid_block_res_sample)
 
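The change above moves residual scaling into ControlNetModel.forward itself. A tiny standalone sketch of the semantics with dummy tensors (not diffusers code, shapes illustrative):

import torch

# Pretend these are the residuals a ControlNet just produced.
down_block_res_samples = [torch.ones(1, 320, 64, 64), torch.ones(1, 640, 32, 32)]
mid_block_res_sample = torch.ones(1, 1280, 8, 8)

conditioning_scale = 0.5  # the new forward() argument

# Exactly what the added lines do: every residual is scaled before it reaches the UNet.
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample *= conditioning_scale

print(down_block_res_samples[0].mean().item(), mid_block_res_sample.mean().item())  # 0.5 0.5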
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

+183 -31
@@ -14,14 +14,18 @@
 
 
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import PIL.Image
 import torch
+from torch import nn
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models.controlnet import ControlNetOutput
+from ...models.modeling_utils import ModelMixin
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
@@ -85,6 +89,63 @@
 """
 
 
+class MultiControlNetModel(ModelMixin):
+    r"""
+    Multiple `ControlNetModel` wrapper class for Multi-ControlNet
+
+    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
+    compatible with `ControlNetModel`.
+
+    Args:
+        controlnets (`List[ControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. You must set multiple
+            `ControlNetModel` as a list.
+    """
+
+    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
+        super().__init__()
+        self.nets = nn.ModuleList(controlnets)
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        controlnet_cond: List[torch.tensor],
+        conditioning_scale: List[float],
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[ControlNetOutput, Tuple]:
+        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+            down_samples, mid_sample = controlnet(
+                sample,
+                timestep,
+                encoder_hidden_states,
+                image,
+                scale,
+                class_labels,
+                timestep_cond,
+                attention_mask,
+                cross_attention_kwargs,
+                return_dict,
+            )
+
+            # merge samples
+            if i == 0:
+                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+            else:
+                down_block_res_samples = [
+                    samples_prev + samples_curr
+                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                ]
+                mid_block_res_sample += mid_sample
+
+        return down_block_res_samples, mid_block_res_sample
+
+
 class StableDiffusionControlNetPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
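A standalone sketch of the merge step in MultiControlNetModel.forward above: residuals from each controlnet are summed element-wise. Dummy tensors, two "controlnets" with one down-block residual each (shapes illustrative):

import torch

# Dummy (down_block_res_samples, mid_block_res_sample) pairs, as if returned by two controlnets.
per_controlnet_outputs = [
    ([torch.full((1, 320, 64, 64), 1.0)], torch.full((1, 1280, 8, 8), 1.0)),
    ([torch.full((1, 320, 64, 64), 2.0)], torch.full((1, 1280, 8, 8), 2.0)),
]

for i, (down_samples, mid_sample) in enumerate(per_controlnet_outputs):
    if i == 0:
        down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
    else:
        down_block_res_samples = [
            samples_prev + samples_curr
            for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
        ]
        mid_block_res_sample += mid_sample

print(down_block_res_samples[0].mean().item(), mid_block_res_sample.mean().item())  # 3.0 3.0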
@@ -103,8 +164,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
             Tokenizer of class
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
-        controlnet ([`ControlNetModel`]):
-            Provides additional conditioning to the unet during the denoising process
+        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
+            as a list, the outputs from each ControlNet are added together to create one combined additional
+            conditioning.
         scheduler ([`SchedulerMixin`]):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -122,7 +185,7 @@ def __init__(
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: ControlNetModel,
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
@@ -146,6 +209,9 @@ def __init__(
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetModel(controlnet)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -432,6 +498,7 @@ def check_inputs(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        controlnet_conditioning_scale=1.0,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -470,6 +537,41 @@ def check_inputs(
                     f" {negative_prompt_embeds.shape}."
                 )
 
+        # Check `image`
+
+        if isinstance(self.controlnet, ControlNetModel):
+            self.check_image(image, prompt, prompt_embeds)
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+            if len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    "For multiple controlnets: `image` must have the same length as the number of controlnets."
+                )
+
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
+        else:
+            assert False
+
+        # Check `controlnet_conditioning_scale`
+
+        if isinstance(self.controlnet, ControlNetModel):
+            if not isinstance(controlnet_conditioning_scale, float):
+                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+                self.controlnet.nets
+            ):
+                raise ValueError(
+                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+                    " the same length as the number of controlnets"
+                )
+        else:
+            assert False
+
+    def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
         image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
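In short, the checks above expect one condition image per controlnet and, when a list scale is used, one scale per controlnet. A tiny sketch of valid inputs for a two-net pipeline (variable names and image contents are illustrative):

import PIL.Image

# Two dummy condition images, one per controlnet in the MultiControlNetModel.
canny_image = PIL.Image.new("RGB", (512, 512))
openpose_image = PIL.Image.new("RGB", (512, 512))

image = [canny_image, openpose_image]        # len(image) must equal len(pipe.controlnet.nets)
controlnet_conditioning_scale = [1.0, 0.8]   # a list scale must match the number of controlnets
# A bare float (e.g. 1.0) is also accepted and broadcast to every controlnet in __call__.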
@@ -501,7 +603,9 @@ def check_inputs(
                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
             )
 
-    def prepare_image(self, image, width, height, batch_size, num_images_per_prompt, device, dtype):
+    def prepare_image(
+        self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+    ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
                 image = [image]
@@ -529,6 +633,9 @@ def prepare_image(self, image, width, height, batch_size, num_images_per_prompt,
 
         image = image.to(device=device, dtype=dtype)
 
+        if do_classifier_free_guidance:
+            image = torch.cat([image] * 2)
+
         return image
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
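The classifier-free-guidance duplication now happens inside prepare_image rather than in __call__. A small standalone sketch of what that duplication does to the condition tensor (shapes assumed):

import torch

cond = torch.rand(1, 3, 512, 512)   # one preprocessed condition image (batch of 1)

do_classifier_free_guidance = True
if do_classifier_free_guidance:
    # Duplicate so the condition batch matches the doubled latent batch (uncond + cond).
    cond = torch.cat([cond] * 2)

print(cond.shape)  # torch.Size([2, 3, 512, 512])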
@@ -550,7 +657,10 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         return latents
 
     def _default_height_width(self, height, width, image):
-        if isinstance(image, list):
+        # NOTE: It is possible that a list of images have different
+        # dimensions for each image, so just checking the first image
+        # is not _exactly_ correct, but it is simple.
+        while isinstance(image, list):
             image = image[0]
 
         if height is None:
@@ -571,6 +681,18 @@ def _default_height_width(self, height, width, image):
 
         return height, width
 
+    # override DiffusionPipeline
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        safe_serialization: bool = False,
+        variant: Optional[str] = None,
+    ):
+        if isinstance(self.controlnet, ControlNetModel):
+            super().save_pretrained(save_directory, safe_serialization, variant)
+        else:
+            raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
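The override above only allows saving in the single-ControlNet case. A hedged sketch of the resulting behavior, assuming pipe_single was built with one ControlNetModel and pipe_multi with a list of them (as in the usage sketch near the top):

pipe_single.save_pretrained("sd-controlnet-pipe")            # delegates to DiffusionPipeline.save_pretrained
try:
    pipe_multi.save_pretrained("sd-multi-controlnet-pipe")   # Multi-ControlNet pipelines cannot be saved yet
except NotImplementedError as err:
    print(err)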
@@ -593,7 +715,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        controlnet_conditioning_scale: float = 1.0,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -602,10 +724,14 @@
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
+                `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
-                the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
-                also be accepted as an image. The control image is automatically resized to fit the output image.
+                the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+                also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
+                height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
+                specified in init, images must be passed as a list such that each element of the list can be correctly
+                batched for input to a single controlnet.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -658,10 +784,10 @@
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-            controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
-                to the residual in the original unet.
-
+                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+                corresponding scale as a list.
         Examples:
 
 
@@ -676,7 +802,15 @@
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt, image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+            prompt,
+            image,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            controlnet_conditioning_scale,
         )
 
         # 2. Define call parameters
@@ -693,6 +827,9 @@
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
 
+        if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
+
         # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
             prompt,
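The two lines added above broadcast a scalar scale to one entry per controlnet. A tiny sketch of the resulting value for a two-net pipeline (num_controlnets stands in for len(self.controlnet.nets)):

num_controlnets = 2                     # len(self.controlnet.nets) in the pipeline
controlnet_conditioning_scale = 0.5     # user passed a bare float

if isinstance(controlnet_conditioning_scale, float):
    controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnets

print(controlnet_conditioning_scale)    # [0.5, 0.5] -> one scale per controlnet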
@@ -705,18 +842,37 @@
         )
 
         # 4. Prepare image
-        image = self.prepare_image(
-            image,
-            width,
-            height,
-            batch_size * num_images_per_prompt,
-            num_images_per_prompt,
-            device,
-            self.controlnet.dtype,
-        )
+        if isinstance(self.controlnet, ControlNetModel):
+            image = self.prepare_image(
+                image=image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=self.controlnet.dtype,
+                do_classifier_free_guidance=do_classifier_free_guidance,
+            )
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            images = []
+
+            for image_ in image:
+                image_ = self.prepare_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=self.controlnet.dtype,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                )
 
-        if do_classifier_free_guidance:
-            image = torch.cat([image] * 2)
+                images.append(image_)
+
+            image = images
+        else:
+            assert False
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -746,20 +902,16 @@
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
+                # controlnet(s) inference
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
                     controlnet_cond=image,
+                    conditioning_scale=controlnet_conditioning_scale,
                     return_dict=False,
                 )
 
-                down_block_res_samples = [
-                    down_block_res_sample * controlnet_conditioning_scale
-                    for down_block_res_sample in down_block_res_samples
-                ]
-                mid_block_res_sample *= controlnet_conditioning_scale
-
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,