import inspect
- from typing import Any, Callable, Dict, List, Optional, Union
+ import os
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
+ from torch import nn

from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+ from ...models.controlnet import ControlNetOutput
+ from ...models.modeling_utils import ModelMixin
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
    PIL_INTERPOLATION,
"""


+ class MultiControlNetModel(ModelMixin):
+     r"""
+     Multiple `ControlNetModel` wrapper class for Multi-ControlNet.
+
+     This module is a wrapper for multiple instances of `ControlNetModel`. The `forward()` API is designed to be
+     compatible with `ControlNetModel`.
+
+     Args:
+         controlnets (`List[ControlNetModel]`):
+             Provides additional conditioning to the unet during the denoising process. You must provide multiple
+             `ControlNetModel` instances as a list.
+     """
+
+     def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
+         super().__init__()
+         self.nets = nn.ModuleList(controlnets)
+
+     def forward(
+         self,
+         sample: torch.FloatTensor,
+         timestep: Union[torch.Tensor, float, int],
+         encoder_hidden_states: torch.Tensor,
+         controlnet_cond: List[torch.Tensor],
+         conditioning_scale: List[float],
+         class_labels: Optional[torch.Tensor] = None,
+         timestep_cond: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         return_dict: bool = True,
+     ) -> Union[ControlNetOutput, Tuple]:
+         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+             down_samples, mid_sample = controlnet(
+                 sample,
+                 timestep,
+                 encoder_hidden_states,
+                 image,
+                 scale,
+                 class_labels,
+                 timestep_cond,
+                 attention_mask,
+                 cross_attention_kwargs,
+                 return_dict,
+             )
+
+             # merge samples
+             if i == 0:
+                 down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+             else:
+                 down_block_res_samples = [
+                     samples_prev + samples_curr
+                     for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                 ]
+                 mid_block_res_sample += mid_sample
+
+         return down_block_res_samples, mid_block_res_sample
+
+
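To make the wrapper's contract concrete, here is a minimal usage sketch. The checkpoint names and the import path are assumptions for illustration, not part of this diff:

```python
from diffusers import ControlNetModel

# Assumed import path: at the time of this PR the class lives in this
# pipeline module rather than the top-level diffusers namespace.
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
    MultiControlNetModel,
)

# Two independently trained ControlNets (illustrative checkpoints).
canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")
multi = MultiControlNetModel([canny, pose])

# forward() takes one conditioning image and one scale per wrapped net and
# returns a single summed set of residuals, e.g. (hypothetical tensors):
# down_res, mid_res = multi(
#     sample, timestep, encoder_hidden_states,
#     controlnet_cond=[canny_cond, pose_cond],
#     conditioning_scale=[1.0, 0.8],
#     return_dict=False,
# )
```
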
class StableDiffusionControlNetPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -103,8 +164,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
-         controlnet ([`ControlNetModel`]):
-             Provides additional conditioning to the unet during the denoising process
+         controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
+             Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
+             as a list, the outputs from each ControlNet are added together to create one combined additional
+             conditioning.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
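From the caller's side, this is what the new `controlnet` argument enables. A minimal sketch, assuming two public ControlNet checkpoints (the names are illustrative):

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnets = [
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16),
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16),
]

# Passing a list is the new behavior; __init__ wraps it in a
# MultiControlNetModel (see the isinstance check added below).
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnets, torch_dtype=torch.float16
)
```
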
@@ -122,7 +185,7 @@ def __init__(
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
-         controlnet: ControlNetModel,
+         controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
@@ -146,6 +209,9 @@ def __init__(
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

+         if isinstance(controlnet, (list, tuple)):
+             controlnet = MultiControlNetModel(controlnet)
+
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
@@ -432,6 +498,7 @@ def check_inputs(
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
+         controlnet_conditioning_scale=1.0,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -470,6 +537,41 @@ def check_inputs(
                f" {negative_prompt_embeds.shape}."
            )

+         # Check `image`
+
+         if isinstance(self.controlnet, ControlNetModel):
+             self.check_image(image, prompt, prompt_embeds)
+         elif isinstance(self.controlnet, MultiControlNetModel):
+             if not isinstance(image, list):
+                 raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+             if len(image) != len(self.controlnet.nets):
+                 raise ValueError(
+                     "For multiple controlnets: `image` must have the same length as the number of controlnets."
+                 )
+
+             for image_ in image:
+                 self.check_image(image_, prompt, prompt_embeds)
+         else:
+             assert False
+
+         # Check `controlnet_conditioning_scale`
+
+         if isinstance(self.controlnet, ControlNetModel):
+             if not isinstance(controlnet_conditioning_scale, float):
+                 raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+         elif isinstance(self.controlnet, MultiControlNetModel):
+             if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+                 self.controlnet.nets
+             ):
+                 raise ValueError(
+                     "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+                     " the same length as the number of controlnets"
+                 )
+         else:
+             assert False
+
+     def check_image(self, image, prompt, prompt_embeds):
        image_is_pil = isinstance(image, PIL.Image.Image)
        image_is_tensor = isinstance(image, torch.Tensor)
        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
@@ -501,7 +603,9 @@ def check_inputs(
                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
            )

-     def prepare_image(self, image, width, height, batch_size, num_images_per_prompt, device, dtype):
+     def prepare_image(
+         self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+     ):
        if not isinstance(image, torch.Tensor):
            if isinstance(image, PIL.Image.Image):
                image = [image]
@@ -529,6 +633,9 @@ def prepare_image(self, image, width, height, batch_size, num_images_per_prompt,
        image = image.to(device=device, dtype=dtype)

+         if do_classifier_free_guidance:
+             image = torch.cat([image] * 2)
+
        return image
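
The `do_classifier_free_guidance` branch is new here: duplicating the control batch keeps it aligned with the duplicated latents in `__call__` (see `torch.cat([latents] * 2)` there). A self-contained shape check, with made-up sizes:

```python
import torch

image = torch.randn(2, 3, 512, 512)     # prepared control images, batch of 2
image = torch.cat([image] * 2)          # what the new branch does under CFG
assert image.shape == (4, 3, 512, 512)  # matches torch.cat([latents] * 2)
```
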

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
@@ -550,7 +657,10 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
        return latents

    def _default_height_width(self, height, width, image):
-         if isinstance(image, list):
+         # NOTE: a list of images may have different dimensions for each
+         # image, so just checking the first image is not _exactly_ correct,
+         # but it is simple.
+         while isinstance(image, list):
            image = image[0]

        if height is None:
@@ -571,6 +681,18 @@ def _default_height_width(self, height, width, image):
        return height, width

+     # override DiffusionPipeline
+     def save_pretrained(
+         self,
+         save_directory: Union[str, os.PathLike],
+         safe_serialization: bool = False,
+         variant: Optional[str] = None,
+     ):
+         if isinstance(self.controlnet, ControlNetModel):
+             super().save_pretrained(save_directory, safe_serialization, variant)
+         else:
+             raise NotImplementedError("Currently, `save_pretrained()` is not implemented for Multi-ControlNet.")
+
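The override deliberately refuses the Multi-ControlNet case. A hypothetical workaround, not part of this diff: since every entry in `nets` is an ordinary `ControlNetModel`, each net can still be saved on its own.

```python
# `pipe` is a StableDiffusionControlNetPipeline built with a list of
# ControlNets; the directory layout here is an assumption.
for i, net in enumerate(pipe.controlnet.nets):
    net.save_pretrained(f"./my_multi_controlnet/controlnet_{i}")
```
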
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -593,7 +715,7 @@ def __call__(
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-         controlnet_conditioning_scale: float = 1.0,
+         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
@@ -602,10 +724,14 @@ def __call__(
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
-             image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
+             image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
+                 `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
-                 the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
-                 also be accepted as an image. The control image is automatically resized to fit the output image.
+                 the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+                 also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
+                 height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
+                 specified in init, images must be passed as a list such that each element of the list can be correctly
+                 batched for input to a single controlnet.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -658,10 +784,10 @@ def __call__(
                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
-             controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+             controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
-                 to the residual in the original unet.
-
+                 to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+                 corresponding scale as a list.
        Examples:

        Returns:
@@ -676,7 +802,15 @@ def __call__(

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
-             prompt, image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+             prompt,
+             image,
+             height,
+             width,
+             callback_steps,
+             negative_prompt,
+             prompt_embeds,
+             negative_prompt_embeds,
+             controlnet_conditioning_scale,
        )

        # 2. Define call parameters
@@ -693,6 +827,9 @@ def __call__(
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

+         if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
+
        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
@@ -705,18 +842,37 @@ def __call__(
        )

        # 4. Prepare image
-         image = self.prepare_image(
-             image,
-             width,
-             height,
-             batch_size * num_images_per_prompt,
-             num_images_per_prompt,
-             device,
-             self.controlnet.dtype,
-         )
+         if isinstance(self.controlnet, ControlNetModel):
+             image = self.prepare_image(
+                 image=image,
+                 width=width,
+                 height=height,
+                 batch_size=batch_size * num_images_per_prompt,
+                 num_images_per_prompt=num_images_per_prompt,
+                 device=device,
+                 dtype=self.controlnet.dtype,
+                 do_classifier_free_guidance=do_classifier_free_guidance,
+             )
+         elif isinstance(self.controlnet, MultiControlNetModel):
+             images = []
+
+             for image_ in image:
+                 image_ = self.prepare_image(
+                     image=image_,
+                     width=width,
+                     height=height,
+                     batch_size=batch_size * num_images_per_prompt,
+                     num_images_per_prompt=num_images_per_prompt,
+                     device=device,
+                     dtype=self.controlnet.dtype,
+                     do_classifier_free_guidance=do_classifier_free_guidance,
+                 )

-         if do_classifier_free_guidance:
-             image = torch.cat([image] * 2)
+                 images.append(image_)
+
+             image = images
+         else:
+             assert False

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -746,20 +902,16 @@ def __call__(
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

+                 # controlnet(s) inference
                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    controlnet_cond=image,
+                     conditioning_scale=controlnet_conditioning_scale,
                    return_dict=False,
                )

-                 down_block_res_samples = [
-                     down_block_res_sample * controlnet_conditioning_scale
-                     for down_block_res_sample in down_block_res_samples
-                 ]
-                 mid_block_res_sample *= controlnet_conditioning_scale
-
                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
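
Note the scaling now happens inside the ControlNet forward pass (`conditioning_scale=`) instead of being applied to the residuals afterwards, which is what lets each net in a Multi-ControlNet carry its own scale. Putting the pieces together, a sketch of a Multi-ControlNet generation call, building on the `pipe` constructed earlier; the conditioning images (`canny_cond`, `pose_cond`), prompt, and seed are assumptions:

```python
import torch

pipe = pipe.to("cuda")
generator = torch.Generator(device="cuda").manual_seed(0)

result = pipe(
    "a bird sitting on a branch",
    image=[canny_cond, pose_cond],             # one control image per net, same order as controlnet=[...]
    controlnet_conditioning_scale=[1.0, 0.8],  # per-net weighting of the added residuals
    num_inference_steps=30,
    generator=generator,
)
result.images[0].save("multi_controlnet.png")
```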