@@ -20,12 +20,12 @@
 import numpy as np
 import PIL.Image
 import torch
-from torch import device, nn
+from torch import nn
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.controlnet import ControlNetOutput
-from ...models.modeling_utils import get_parameter_device, get_parameter_dtype
+from ...models.modeling_utils import ModelMixin
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
@@ -89,7 +89,7 @@
 """


-class MultiControlNet(nn.Module):
+class MultiControlNetModel(ModelMixin):
     r"""
     Multiple `ControlNetModel` wrapper class for Multi-ControlNet

@@ -102,25 +102,10 @@ class MultiControlNet(nn.Module):
             `ControlNetModel` as a list.
     """

-    def __init__(self, controlnets: List[ControlNetModel]):
+    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
         super().__init__()
         self.nets = nn.ModuleList(controlnets)

-    @property
-    def device(self) -> device:
-        """
-        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
-        device).
-        """
-        return get_parameter_device(self)
-
-    @property
-    def dtype(self) -> torch.dtype:
-        """
-        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
-        """
-        return get_parameter_dtype(self)
-
     def forward(
         self,
         sample: torch.FloatTensor,
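
The hunk truncates `forward()` after its first arguments. For readers of the diff, a minimal sketch of the residual-summation behavior the new docstring describes might look like the following; the argument names beyond `sample` and the exact `ControlNetModel.forward` return shape are assumptions, not part of this change:

```python
from torch import nn


class MultiControlNetModelSketch(nn.Module):
    """Sketch of the wrapper's forward pass, not the actual diffusers implementation."""

    def __init__(self, controlnets):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(self, sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale):
        # Assumes each wrapped net returns (down_block_res_samples, mid_block_res_sample).
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            down_samples, mid_sample = controlnet(
                sample, timestep, encoder_hidden_states, image, scale, return_dict=False
            )
            if i == 0:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                # Outputs from each ControlNet are added together into one combined conditioning.
                down_block_res_samples = [
                    prev + curr for prev, curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample = mid_block_res_sample + mid_sample
        return down_block_res_samples, mid_block_res_sample
```
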
@@ -180,8 +165,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
         controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
-            Provides additional conditioning to the unet during the denoising process. You can set multiple
-            `ControlNetModel` as a list.
+            Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
+            as a list, the outputs from each ControlNet are added together to create one combined additional
+            conditioning.
         scheduler ([`SchedulerMixin`]):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -199,7 +185,7 @@ def __init__(
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: Union[ControlNetModel, List[ControlNetModel]],
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
@@ -224,9 +210,7 @@ def __init__(
             )

         if isinstance(controlnet, (list, tuple)):
-            controlnet = MultiControlNet(controlnet)
-        else:
-            controlnet = controlnet
+            controlnet = MultiControlNetModel(controlnet)

         self.register_modules(
             vae=vae,
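
Passing a list (or tuple) of ControlNets to the constructor now wraps them in a `MultiControlNetModel` transparently. A usage sketch, with two publicly available checkpoints chosen purely for illustration:

```python
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Illustrative checkpoints; any set of compatible ControlNets works.
pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")
canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

# The list is wrapped into a MultiControlNetModel inside __init__.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=[pose, canny]
)
```
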
@@ -507,12 +491,14 @@ def prepare_extra_step_kwargs(self, generator, eta):
     def check_inputs(
         self,
         prompt,
+        image,
         height,
         width,
         callback_steps,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        controlnet_conditioning_scale=1.0,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -551,6 +537,40 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )

+        # Check `image`
+
+        if isinstance(self.controlnet, ControlNetModel):
+            self.check_image(image, prompt, prompt_embeds)
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+            if len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    "For multiple controlnets: `image` must have the same length as the number of controlnets."
+                )
+
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
+        else:
+            assert False
+
+        # Check `controlnet_conditioning_scale`
+
+        if isinstance(self.controlnet, ControlNetModel):
+            if not isinstance(controlnet_conditioning_scale, float):
+                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+                self.controlnet.nets
+            ):
+                raise ValueError(
+                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+                    " the same length as the number of controlnets"
+                )
+        else:
+            assert False
+
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
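
Under these checks, a multi-ControlNet call must supply exactly one conditioning image per net, and a list-valued scale must match the number of nets. A sketch of accepted and rejected inputs, assuming `pipe` wraps two ControlNets and the image variables are hypothetical:

```python
# Accepted: one conditioning image per ControlNet, one scale per ControlNet.
pipe(prompt, image=[pose_image, canny_image], controlnet_conditioning_scale=[1.0, 0.5])

# Rejected with TypeError: multi-ControlNet requires `image` to be a list.
pipe(prompt, image=pose_image)

# Rejected with ValueError: two nets, but three images.
pipe(prompt, image=[pose_image, canny_image, extra_image])
```
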
@@ -637,7 +657,10 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         return latents

     def _default_height_width(self, height, width, image):
-        if isinstance(image, list):
+        # NOTE: It is possible that each image in a list has different
+        # dimensions, so just checking the first image is not _exactly_
+        # correct, but it is simple.
+        while isinstance(image, list):
             image = image[0]

         if height is None:
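
The `while` loop lets the default-size logic descend through the new nested-list form of `image`. A small illustration with hypothetical PIL images:

```python
image = [[pil_a, pil_b], [pil_c, pil_d]]  # one sub-list per ControlNet (hypothetical images)
while isinstance(image, list):
    image = image[0]
# image is now pil_a; its height and width supply the defaults
```
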
@@ -658,41 +681,6 @@ def _default_height_width(self, height, width, image):

         return height, width

-    def _prepare_images(self, image):
-        if isinstance(self.controlnet, ControlNetModel):
-            return [image]  # convert to array for internal use
-        else:  # Multi-Controlnet
-            if not isinstance(image, list):
-                raise ValueError("The `image` argument needs to be specified in a `list`.")
-
-            num_controlnets = len(self.controlnet.nets)
-            if len(image) % num_controlnets != 0:
-                raise ValueError(
-                    "The length of the `image` argument list needs to be a multiple of the number of Multi-ControlNet."
-                )
-
-            image_per_control = len(image) // num_controlnets
-
-            # let's split images over controlnets
-            return [image[i : i + image_per_control] for i in range(0, len(image), image_per_control)]
-
-    def _prepare_controlnet_conditioning_scale(self, controlnet_conditioning_scale):
-        if isinstance(self.controlnet, ControlNetModel):
-            if not isinstance(controlnet_conditioning_scale, float):
-                raise ValueError("The `controlnet_conditioning_scale` argument needs to be specified as a `float`.")
-            return controlnet_conditioning_scale
-        else:  # Multi-Controlnet
-            num_controlnets = len(self.controlnet.nets)
-            if isinstance(controlnet_conditioning_scale, list):
-                if len(controlnet_conditioning_scale) != num_controlnets:
-                    raise ValueError(
-                        "The length of the `controlnet_conditioning_scale` list does not match the number of Multi-ControlNet. "
-                        "If specified in `list`, it needs to have the same length as the number of Multi-ControlNet."
-                    )
-            else:
-                controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnets
-            return controlnet_conditioning_scale
-
     # override DiffusionPipeline
     def save_pretrained(
         self,
@@ -736,12 +724,14 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
+                `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
                 the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
-                also be accepted as an image. The control image is automatically resized to fit the output image. If
-                multiple ControlNets are specified in init, you need to set the corresponding images in the form of a
-                list of `List[torch.FloatTensor]` or `List[PIL.Image.Image]`.
+                also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
+                height and/or width are passed, `image` is resized accordingly. If multiple ControlNets are
+                specified in init, images must be passed as a list such that each element of the list can be correctly
+                batched for input to a single controlnet.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -807,21 +797,21 @@ def __call__(
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-
-        # prepare `images` and `controlnet_conditioning_scale` for both a ControlNet and Multi-Controlnet
-        # `images` here is a list where each element is a conditioning image for each ControlNet.
-        images = self._prepare_images(image)
-        controlnet_conditioning_scale = self._prepare_controlnet_conditioning_scale(controlnet_conditioning_scale)
-
         # 0. Default height and width to unet
-        height, width = self._default_height_width(height, width, images[0])
+        height, width = self._default_height_width(height, width, image)

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+            prompt,
+            image,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            controlnet_conditioning_scale,
         )
-        for image in images:
-            self.check_image(image, prompt, prompt_embeds)

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -837,6 +827,9 @@ def __call__(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0

+        if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
+
         # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
             prompt,
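
This broadcast means a scalar scale now applies uniformly across all nets. A worked micro-example, assuming two ControlNets:

```python
# A float scale is expanded to one entry per net before the denoising loop.
controlnet_conditioning_scale = 0.5
num_nets = 2  # stands in for len(self.controlnet.nets)
controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_nets  # -> [0.5, 0.5]
```
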
@@ -849,8 +842,8 @@ def __call__(
         )

         # 4. Prepare image
-        images = [
-            self.prepare_image(
+        if isinstance(self.controlnet, ControlNetModel):
+            image = self.prepare_image(
                 image=image,
                 width=width,
                 height=height,
@@ -860,8 +853,26 @@ def __call__(
                 dtype=self.controlnet.dtype,
                 do_classifier_free_guidance=do_classifier_free_guidance,
             )
-            for image in images
-        ]
+        elif isinstance(self.controlnet, MultiControlNetModel):
+            images = []
+
+            for image_ in image:
+                image_ = self.prepare_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=self.controlnet.dtype,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                )
+
+                images.append(image_)
+
+            image = images
+        else:
+            assert False

         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -896,7 +907,7 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
-                    controlnet_cond=images[0] if len(images) == 1 else images,
+                    controlnet_cond=image,
                     conditioning_scale=controlnet_conditioning_scale,
                     return_dict=False,
                 )
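
Taken together, an end-to-end call under the new interface might look like the following sketch; the checkpoints, prompt, and input URLs are illustrative, not part of the diff:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Illustrative conditioning inputs prepared elsewhere (e.g. a pose map and an edge map).
pose_image = load_image("https://example.com/pose.png")
canny_image = load_image("https://example.com/canny.png")

pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)
canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=[pose, canny], torch_dtype=torch.float16
).to("cuda")

# One conditioning image per ControlNet; a float scale would be broadcast instead.
result = pipe(
    "a dancer on the beach, best quality",
    image=[pose_image, canny_image],
    controlnet_conditioning_scale=[1.0, 0.5],
).images[0]
result.save("multi_controlnet.png")
```
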