@@ -92,10 +92,8 @@ class ControlNetCondition:
     def __init__(
         self,
         image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]],
-        scale: float = 1.0,
     ):
         self.image_original = image
-        self.scale = scale
 
     def _default_height_width(self, height, width, image):
         if isinstance(image, list):
@@ -226,38 +224,28 @@ def forward(
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-        images: List[torch.tensor],
-        cond_scales: List[int],
+        controlnet_cond: List[torch.tensor],
+        conditioning_scale: List[float],
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
-        if len(controlnet_conditions) != len(self.nets):
-            raise ValueError(
-                f"The number of specified `ControlNetCondition` does not match the number of `ControlNetModel`."
-                f"There are {len(self.nets)} ControlNetModel(s), "
-                f"but there are {len(controlnet_conditions)} `ControlNetCondition` in `controlnet_conditions`."
-            )
-
-        for i, (image, scale, controlnet) in enumerate(zip(images, scales, self.nets)):
+        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
             down_samples, mid_sample = controlnet(
                 sample,
                 timestep,
                 encoder_hidden_states,
                 image,
+                scale,
                 class_labels,
                 timestep_cond,
                 attention_mask,
                 cross_attention_kwargs,
                 return_dict,
             )
 
-            # scaling
-            down_samples = [sample * cond.scale for sample in down_samples]
-            mid_sample *= cond.scale
-
             # merge samples
             if i == 0:
                 down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
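For reference, a minimal, self-contained sketch of what the refactored multi-controlnet `forward` loop above does: each controlnet now receives its own `conditioning_scale` entry instead of having its residuals rescaled afterwards, and the per-controlnet residuals are then merged. The `else` branch of the merge lies outside this hunk, so combining later results by element-wise addition is an assumption, and `fake_controlnet` is a stand-in for the real `ControlNetModel` call.

```python
# Sketch only: `fake_controlnet` stands in for ControlNetModel, and the
# element-wise addition in the `else` branch is an assumption (that branch
# is not shown in the hunk above).
from typing import List, Tuple

import torch


def fake_controlnet(cond: torch.Tensor, scale: float) -> Tuple[List[torch.Tensor], torch.Tensor]:
    """Stand-in for ControlNetModel: returns already-scaled down/mid residuals."""
    down_samples = [cond * scale, cond * scale * 0.5]
    mid_sample = cond.mean() * scale
    return down_samples, mid_sample


def merge_controlnet_residuals(
    controlnet_cond: List[torch.Tensor], conditioning_scale: List[float]
) -> Tuple[List[torch.Tensor], torch.Tensor]:
    down_block_res_samples, mid_block_res_sample = None, None
    for i, (image, scale) in enumerate(zip(controlnet_cond, conditioning_scale)):
        down_samples, mid_sample = fake_controlnet(image, scale)
        if i == 0:
            # first controlnet initializes the residuals
            down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
        else:
            # assumed merge rule: add residuals from each additional controlnet
            down_block_res_samples = [
                prev + cur for prev, cur in zip(down_block_res_samples, down_samples)
            ]
            mid_block_res_sample = mid_block_res_sample + mid_sample
    return down_block_res_samples, mid_block_res_sample


if __name__ == "__main__":
    conds = [torch.ones(1, 3, 8, 8), torch.ones(1, 3, 8, 8)]
    down, mid = merge_controlnet_residuals(conds, [1.0, 0.5])
    print(len(down), mid.item())
```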
@@ -700,8 +688,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        controlnet_conditioning_scale: List[float] = 1.0,
-        controlnet_conditions: Optional[List[ControlNetCondition]] = None,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -780,18 +767,28 @@ def __call__(
                 (nsfw) content, according to the `safety_checker`.
         """
 
-        # TODO: add conversion image array to ControlNetConditions
-        if controlnet_conditions is None:
+        # TODO: refactoring
+        if isinstance(self.controlnet, ControlNetModel):
+            controlnet_conditions = [ControlNetCondition(image=image)]
+        elif isinstance(self.controlnet, MultiControlNet):
+            # TODO: refactoring
             # let's split images over controlnets
-            image_per_control = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
-            if image_per_control > 1 and not isinstance(image, list):
-                raise ValueError(...)
-
-            if len(image) % image_per_control != 0:
-                raise ValueError(...)
-
-            images = [image[i:i + num_image_per_control] for i in range(0, len(image), image_per_control)]
-            controlnet_conditions = [ControlNetCondition(image=image, scale=scale) for image, scale in zip(images, controlnet_conditioning_scale)]
+            # image_per_control = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
+            # if image_per_control > 1 and not isinstance(image, list):
+            #     raise ValueError(...)
+
+            # if len(image) % image_per_control != 0:
+            #     raise ValueError(...)
+
+            # if image_per_control > 1 and not isinstance(controlnet_conditioning_scale, list):
+            #     controlnet_conditioning_scale = [controlnet_conditioning_scale] * image_per_control
+
+            # images = [image[i:i+image_per_control] for i in range(0, len(image), image_per_control)]
+
+            num_controlnets = len(self.controlnet.nets)
+            if num_controlnets > 1 and not isinstance(controlnet_conditioning_scale, list):
+                controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnets
+            controlnet_conditions = [ControlNetCondition(image=img) for img in image]
 
         # 0. Default height and width to unet
         height, width = controlnet_conditions[0].default_height_width(height, width)
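The hunk above also changes how `controlnet_conditioning_scale` is interpreted: a single float is broadcast to one scale per controlnet, while a list is passed through unchanged. A small illustrative sketch of that broadcast; the helper name is not part of the diff, and the length check is an added safeguard rather than something shown above.

```python
# Illustrative helper (not part of the diff) mirroring the broadcast of
# `controlnet_conditioning_scale` to one value per controlnet.
from typing import List, Union


def broadcast_conditioning_scale(
    conditioning_scale: Union[float, List[float]], num_controlnets: int
) -> List[float]:
    if isinstance(conditioning_scale, list):
        # added safeguard for illustration; the diff above does not check lengths
        if len(conditioning_scale) != num_controlnets:
            raise ValueError(
                f"Got {len(conditioning_scale)} scales for {num_controlnets} controlnets."
            )
        return conditioning_scale
    # a single float applies to every controlnet
    return [conditioning_scale] * num_controlnets


print(broadcast_conditioning_scale(1.0, 3))         # [1.0, 1.0, 1.0]
print(broadcast_conditioning_scale([1.0, 0.5], 2))  # [1.0, 0.5]
```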
@@ -829,14 +826,14 @@ def __call__(
         )
 
         # 4. Prepare image
-        for cond, controlnet in zip(controlnet_conditions, self.controlnet.nets):
+        for cond in controlnet_conditions:
             cond.prepare_image(
                 width=width,
                 height=height,
                 batch_size=batch_size * num_images_per_prompt,
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
-                dtype=controlnet.dtype,
+                dtype=self.controlnet.dtype,
                 do_classifier_free_guidance=do_classifier_free_guidance,
             )
 
@@ -869,11 +866,16 @@ def __call__(
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # controlnet(s) inference
+                if len(controlnet_conditions) > 1:
+                    controlnet_cond = [c.image for c in controlnet_conditions]
+                else:
+                    controlnet_cond = controlnet_conditions[0].image
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
-                    controlnet_conditions=controlnet_conditions,
+                    controlnet_cond=controlnet_cond,
+                    conditioning_scale=controlnet_conditioning_scale,
                     return_dict=False,
                 )
 
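Taken together, the refactor lets callers pass conditioning images and scales directly instead of building `ControlNetCondition` objects themselves. A hedged usage sketch follows, assuming the surrounding pipeline is diffusers' `StableDiffusionControlNetPipeline` and that it accepts a list of `ControlNetModel`s (wrapped by the multi-controlnet class shown above); the checkpoint names and image URLs are placeholders, not taken from this diff.

```python
# Hedged usage sketch (not part of the diff): drive two controlnets in one call,
# passing the conditioning images as a list and one scale per controlnet.
# Checkpoint names and URLs are placeholders / assumptions about the surrounding code.
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

canny_image = load_image("https://example.com/canny.png")  # placeholder URL
pose_image = load_image("https://example.com/pose.png")    # placeholder URL

controlnets = [
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16),
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16),
]
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnets, torch_dtype=torch.float16
).to("cuda")

# One image per controlnet; a single float would be broadcast to every controlnet.
result = pipe(
    "a futuristic city at dusk",
    image=[canny_image, pose_image],
    controlnet_conditioning_scale=[1.0, 0.8],
)
result.images[0].save("output.png")
```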