@@ -224,11 +224,13 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+        )
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor * 2,
-            vae_latent_channels=latent_channels,
+            vae_latent_channels=self.latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
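
For a quick sanity check of the packing arithmetic that comment describes (a sketch; the concrete numbers assume the stock Flux VAE with four block_out_channels and are not part of this diff):

vae_scale_factor = 2 ** (4 - 1)      # 8x spatial compression by the VAE
multiple = vae_scale_factor * 2      # 16: inputs must be divisible by this for 2x2 packing
height, width = 1024, 1024           # pipeline default
assert height % multiple == 0 and width % multiple == 0
latent_h, latent_w = height // vae_scale_factor, width // vae_scale_factor  # 128 x 128
seq_len = (latent_h // 2) * (latent_w // 2)  # 4096 tokens after 2x2 packing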
@@ -493,10 +495,38 @@ def encode_prompt(
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+        return image_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
+        strength,
         height,
         width,
         prompt_embeds=None,
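
For reference, the shift/scale line in `_encode_vae_image` is the usual latent normalization for this VAE family, and decoding applies the inverse. A minimal sketch; the shift_factor/scaling_factor values below are the commonly published FLUX.1 VAE config numbers, used purely for illustration:

import torch

shift_factor, scaling_factor = 0.1159, 0.3611     # assumed FLUX.1 VAE config values
z = torch.randn(1, 16, 128, 128)                  # raw VAE posterior sample
latents = (z - shift_factor) * scaling_factor     # normalization applied by _encode_vae_image
z_back = latents / scaling_factor + shift_factor  # what decoding undoes first
assert torch.allclose(z, z_back, atol=1e-6)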
@@ -507,6 +537,9 @@ def check_inputs(
         mask_image=None,
         masked_image_latents=None,
     ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             logger.warning(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
@@ -624,9 +657,11 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()
 
-    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+    # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self,
+        image,
+        timestep,
         batch_size,
         num_channels_latents,
         height,
@@ -636,28 +671,41 @@ def prepare_latents(
         generator,
         latents=None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         # VAE applies 8x compression on images but we must also account for packing which requires
         # latent height and width to be divisible by 2.
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
-
         shape = (batch_size, num_channels_latents, height, width)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids
 
-        if isinstance(generator, list) and len(generator) != batch_size:
+        image = image.to(device=device, dtype=dtype)
+        if image.shape[1] != self.latent_channels:
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            image_latents = image
+        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+            # expand init_latents for batch_size
+            additional_image_per_prompt = batch_size // image_latents.shape[0]
+            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
             raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
             )
+        else:
+            image_latents = torch.cat([image_latents], dim=0)
 
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
         return latents, latent_image_ids
 
     @property
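
The img2img-specific change here is `scheduler.scale_noise`: rather than starting from pure Gaussian noise, the encoded image is blended toward noise at the sigma chosen by `strength`. A minimal sketch of the interpolation the flow-match Euler scheduler performs inside that call (toy tensors, shown for intuition only):

import torch

sigma = 0.4                                    # set by the strength-trimmed first timestep
image_latents = torch.randn(1, 16, 128, 128)   # stand-in for the encoded init image
noise = torch.randn_like(image_latents)
latents = sigma * noise + (1.0 - sigma) * image_latents  # sigma=1.0 -> pure noise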
@@ -687,6 +735,7 @@ def __call__(
         masked_image_latents: Optional[torch.FloatTensor] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        strength: float = 1.0,
         num_inference_steps: int = 50,
         sigmas: Optional[List[float]] = None,
         guidance_scale: float = 30.0,
@@ -731,6 +780,12 @@ def __call__(
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+                essentially ignores `image`.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
@@ -794,6 +849,7 @@ def __call__(
         self.check_inputs(
             prompt,
             prompt_2,
+            strength,
            height,
            width,
            prompt_embeds=prompt_embeds,
@@ -809,6 +865,9 @@ def __call__(
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False
 
+        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        init_image = init_image.to(dtype=torch.float32)
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -838,9 +897,37 @@ def __call__(
             lora_scale=lora_scale,
         )
 
-        # 4. Prepare latent variables
+        # 4. Prepare timesteps
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.15),
+        )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 5. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         latents, latent_image_ids = self.prepare_latents(
+            init_image,
+            latent_timestep,
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
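
For intuition about `mu` in the block above: `calculate_shift` maps the packed token count linearly from base_shift at base_image_seq_len to max_shift at max_image_seq_len. A worked sketch at the 1024x1024 default (the linear map is re-derived here purely for illustration):

height = width = 1024
vae_scale_factor = 8
image_seq_len = (height // vae_scale_factor // 2) * (width // vae_scale_factor // 2)  # 64 * 64 = 4096
base_seq, max_seq, base_shift, max_shift = 256, 4096, 0.5, 1.15
m = (max_shift - base_shift) / (max_seq - base_seq)
mu = image_seq_len * m + (base_shift - m * base_seq)  # = max_shift = 1.15 at 4096 tokens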
@@ -851,17 +938,16 @@ def __call__(
             latents,
         )
 
-        # 5. Prepare mask and masked image latents
+        # 6. Prepare mask and masked image latents
         if masked_image_latents is not None:
             masked_image_latents = masked_image_latents.to(latents.device)
         else:
-            image = self.image_processor.preprocess(image, height=height, width=width)
             mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
 
-            masked_image = image * (1 - mask_image)
+            masked_image = init_image * (1 - mask_image)
             masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
 
-            height, width = image.shape[-2:]
+            height, width = init_image.shape[-2:]
             mask, masked_image_latents = self.prepare_mask_latents(
                 mask_image,
                 masked_image,
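
One convention worth spelling out: the mask processor binarizes to {0, 1} with 1 marking pixels to repaint, so `init_image * (1 - mask_image)` blanks exactly the region being inpainted from the conditioning image. A toy illustration with dummy tensors (not pipeline objects):

import torch

init_image = torch.ones(1, 3, 4, 4)           # dummy RGB image in place of the real input
mask_image = torch.zeros(1, 1, 4, 4)
mask_image[..., 2:, 2:] = 1.0                 # repaint the bottom-right quadrant
masked_image = init_image * (1 - mask_image)  # conditioning keeps only unmasked pixels
assert masked_image[..., 2:, 2:].abs().sum() == 0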
@@ -876,23 +962,6 @@ def __call__(
         )
         masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
 
-        # 6. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
-        image_seq_len = latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.get("base_image_seq_len", 256),
-            self.scheduler.config.get("max_image_seq_len", 4096),
-            self.scheduler.config.get("base_shift", 0.5),
-            self.scheduler.config.get("max_shift", 1.15),
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            sigmas=sigmas,
-            mu=mu,
-        )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)