20
20
import numpy as np
21
21
import PIL .Image
22
22
import torch
23
+ import torch .nn .functional as F
23
24
from torch import nn
24
25
from transformers import CLIPImageProcessor , CLIPTextModel , CLIPTokenizer
25
26
@@ -579,9 +580,20 @@ def check_inputs(
579
580
)
580
581
581
582
# Check `image`
582
- if isinstance (self .controlnet , ControlNetModel ):
583
+ is_compiled = hasattr (F , "scaled_dot_product_attention" ) and isinstance (
584
+ self .controlnet , torch ._dynamo .eval_frame .OptimizedModule
585
+ )
586
+ if (
587
+ isinstance (self .controlnet , ControlNetModel )
588
+ or is_compiled
589
+ and isinstance (self .controlnet ._orig_mod , ControlNetModel )
590
+ ):
583
591
self .check_image (image , prompt , prompt_embeds )
584
- elif isinstance (self .controlnet , MultiControlNetModel ):
592
+ elif (
593
+ isinstance (self .controlnet , MultiControlNetModel )
594
+ or is_compiled
595
+ and isinstance (self .controlnet ._orig_mod , MultiControlNetModel )
596
+ ):
585
597
if not isinstance (image , list ):
586
598
raise TypeError ("For multiple controlnets: `image` must be type `list`" )
587
599
@@ -600,10 +612,18 @@ def check_inputs(
600
612
assert False
601
613
602
614
# Check `controlnet_conditioning_scale`
603
- if isinstance (self .controlnet , ControlNetModel ):
615
+ if (
616
+ isinstance (self .controlnet , ControlNetModel )
617
+ or is_compiled
618
+ and isinstance (self .controlnet ._orig_mod , ControlNetModel )
619
+ ):
604
620
if not isinstance (controlnet_conditioning_scale , float ):
605
621
raise TypeError ("For single controlnet: `controlnet_conditioning_scale` must be type `float`." )
606
- elif isinstance (self .controlnet , MultiControlNetModel ):
622
+ elif (
623
+ isinstance (self .controlnet , MultiControlNetModel )
624
+ or is_compiled
625
+ and isinstance (self .controlnet ._orig_mod , MultiControlNetModel )
626
+ ):
607
627
if isinstance (controlnet_conditioning_scale , list ):
608
628
if any (isinstance (i , list ) for i in controlnet_conditioning_scale ):
609
629
raise ValueError ("A single batch of multiple conditionings are supported at the moment." )
@@ -910,7 +930,14 @@ def __call__(
910
930
)
911
931
912
932
# 4. Prepare image
913
- if isinstance (self .controlnet , ControlNetModel ):
933
+ is_compiled = hasattr (F , "scaled_dot_product_attention" ) and isinstance (
934
+ self .controlnet , torch ._dynamo .eval_frame .OptimizedModule
935
+ )
936
+ if (
937
+ isinstance (self .controlnet , ControlNetModel )
938
+ or is_compiled
939
+ and isinstance (self .controlnet ._orig_mod , ControlNetModel )
940
+ ):
914
941
image = self .prepare_image (
915
942
image = image ,
916
943
width = width ,
@@ -922,7 +949,11 @@ def __call__(
922
949
do_classifier_free_guidance = do_classifier_free_guidance ,
923
950
guess_mode = guess_mode ,
924
951
)
925
- elif isinstance (self .controlnet , MultiControlNetModel ):
952
+ elif (
953
+ isinstance (self .controlnet , MultiControlNetModel )
954
+ or is_compiled
955
+ and isinstance (self .controlnet ._orig_mod , MultiControlNetModel )
956
+ ):
926
957
images = []
927
958
928
959
for image_ in image :
@@ -1006,15 +1037,16 @@ def __call__(
1006
1037
cross_attention_kwargs = cross_attention_kwargs ,
1007
1038
down_block_additional_residuals = down_block_res_samples ,
1008
1039
mid_block_additional_residual = mid_block_res_sample ,
1009
- ).sample
1040
+ return_dict = False ,
1041
+ )[0 ]
1010
1042
1011
1043
# perform guidance
1012
1044
if do_classifier_free_guidance :
1013
1045
noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
1014
1046
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond )
1015
1047
1016
1048
# compute the previous noisy sample x_t -> x_t-1
1017
- latents = self .scheduler .step (noise_pred , t , latents , ** extra_step_kwargs ). prev_sample
1049
+ latents = self .scheduler .step (noise_pred , t , latents , ** extra_step_kwargs , return_dict = False )[ 0 ]
1018
1050
1019
1051
# call the callback, if provided
1020
1052
if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
0 commit comments