Commit 94f2c48
[feat]Add strength in flux_fill pipeline (denoising strength for fluxfill) (#10603)
* [feat]add strength in flux_fill pipeline
* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py
* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py
* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py
* [refactor] refactor after review
* [fix] change comment
* Apply style fixes
* empty
* fix
* update prepare_latents from flux.img2img pipeline
* style
* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py
1 parent aabf8ce commit 94f2c48
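
For orientation before the diff: with this change, `FluxFillPipeline` accepts a `strength` argument analogous to the img2img pipelines. A minimal usage sketch (the checkpoint id is the public FLUX.1 Fill model; the prompt and file paths are illustrative):

```python
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("input.png")  # image to inpaint (illustrative path)
mask = load_image("mask.png")    # white (1) marks the region to fill

result = pipe(
    prompt="a red brick fireplace",
    image=image,
    mask_image=mask,
    strength=0.85,               # new: values < 1.0 keep more of the original image
    num_inference_steps=50,
).images[0]
result.save("output.png")
```

With the default `strength=1.0`, the pipeline behaves exactly as it did before this commit.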

File tree

1 file changed: +104 −35 lines

Diff for: src/diffusers/pipelines/flux/pipeline_flux_fill.py

@@ -224,11 +224,13 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+        )
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor * 2,
-            vae_latent_channels=latent_channels,
+            vae_latent_channels=self.latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
@@ -493,10 +495,38 @@ def encode_prompt(

         return prompt_embeds, pooled_prompt_embeds, text_ids

+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+        return image_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
+        strength,
         height,
         width,
         prompt_embeds=None,
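
The `strength`-to-steps arithmetic in `get_timesteps` is easy to check in isolation; below is a standalone sketch of the same computation (scheduler-free; `order` is 1 for the flow-matching scheduler this pipeline uses):

```python
def effective_steps(num_inference_steps: int, strength: float, order: int = 1):
    # Mirrors get_timesteps: strength decides how many of the scheduled
    # steps are skipped before denoising starts.
    init_timestep = min(num_inference_steps * strength, num_inference_steps)
    t_start = int(max(num_inference_steps - init_timestep, 0))
    return t_start * order, num_inference_steps - t_start

print(effective_steps(50, 1.0))  # (0, 50)  full schedule, image fully noised
print(effective_steps(50, 0.6))  # (20, 30) skip 20 steps, enter partway down the noise schedule
print(effective_steps(50, 0.0))  # (50, 0)  zero steps; __call__ rejects this further below
```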
@@ -507,6 +537,9 @@ def check_inputs(
         mask_image=None,
         masked_image_latents=None,
     ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             logger.warning(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
@@ -624,9 +657,11 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()

-    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+    # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self,
+        image,
+        timestep,
         batch_size,
         num_channels_latents,
         height,
@@ -636,28 +671,41 @@ def prepare_latents(
         generator,
         latents=None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         # VAE applies 8x compression on images but we must also account for packing which requires
         # latent height and width to be divisible by 2.
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
-
         shape = (batch_size, num_channels_latents, height, width)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids

-        if isinstance(generator, list) and len(generator) != batch_size:
+        image = image.to(device=device, dtype=dtype)
+        if image.shape[1] != self.latent_channels:
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            image_latents = image
+        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+            # expand init_latents for batch_size
+            additional_image_per_prompt = batch_size // image_latents.shape[0]
+            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
             raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
             )
+        else:
+            image_latents = torch.cat([image_latents], dim=0)

-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
         return latents, latent_image_ids

     @property
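
The key behavioral change sits at the end of `prepare_latents`: instead of drawing pure noise, the pipeline now noises the encoded input image to the starting timestep with `scheduler.scale_noise`. For `FlowMatchEulerDiscreteScheduler` this is a straight-line interpolation between image and noise; a minimal sketch of that rule (with sigma as the normalized noise level in [0, 1]):

```python
import torch

def scale_noise_sketch(image_latents: torch.Tensor, noise: torch.Tensor, sigma: float) -> torch.Tensor:
    # Flow-matching forward process: sigma = 1.0 yields pure noise (the old
    # behavior, and what strength=1.0 selects); sigma = 0.0 returns the
    # image latents unchanged.
    return sigma * noise + (1.0 - sigma) * image_latents
```

Because the first sigma in the schedule is 1.0, `strength=1.0` still starts from pure noise and reproduces the pre-commit behavior; lower strengths enter the schedule at a smaller sigma and preserve more of `image`.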
@@ -687,6 +735,7 @@ def __call__(
         masked_image_latents: Optional[torch.FloatTensor] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        strength: float = 1.0,
         num_inference_steps: int = 50,
         sigmas: Optional[List[float]] = None,
         guidance_scale: float = 30.0,
@@ -731,6 +780,12 @@ def __call__(
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as
+                a starting point, and more noise is added the higher the `strength`. The number of denoising steps
+                depends on the amount of noise initially added. When `strength` is 1, the added noise is maximal and
+                the denoising process runs for the full number of iterations specified in `num_inference_steps`; a
+                value of 1 therefore essentially ignores `image`.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
@@ -794,6 +849,7 @@ def __call__(
         self.check_inputs(
             prompt,
             prompt_2,
+            strength,
             height,
             width,
             prompt_embeds=prompt_embeds,
@@ -809,6 +865,9 @@ def __call__(
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False

+        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        init_image = init_image.to(dtype=torch.float32)
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
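
`init_image` is now preprocessed once up front because it is needed twice: to build the initial latents and, further down, the masked image. For reference, `VaeImageProcessor.preprocess` returns a `(batch, channels, height, width)` tensor scaled to [-1, 1]; a toy check (the sizes here are illustrative):

```python
from PIL import Image
from diffusers.image_processor import VaeImageProcessor

proc = VaeImageProcessor(vae_scale_factor=16)  # 8 (VAE) * 2 (packing), as in this pipeline
img = Image.new("RGB", (1024, 1024))           # toy black image
t = proc.preprocess(img, height=1024, width=1024)
print(t.shape, t.dtype)                        # torch.Size([1, 3, 1024, 1024]) torch.float32
```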
@@ -838,9 +897,37 @@ def __call__(
             lora_scale=lora_scale,
         )

-        # 4. Prepare latent variables
+        # 4. Prepare timesteps
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.15),
+        )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 5. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         latents, latent_image_ids = self.prepare_latents(
+            init_image,
+            latent_timestep,
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
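
The timestep preparation moves up here (it used to be step 6, after the mask latents) because `prepare_latents` now needs `latent_timestep`. `image_seq_len` is likewise computed from `height`/`width` rather than `latents.shape[1]`, since the latents do not exist yet at this point. `mu` is a linear function of that sequence length; a sketch of the helper's arithmetic with the defaults read from the scheduler config above:

```python
def calculate_shift_sketch(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    # Linear interpolation: larger images (longer token sequences) get a
    # larger shift, spending more of the schedule at high noise levels.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# 1024x1024 with vae_scale_factor=8: (1024 // 8 // 2) ** 2 = 4096 tokens
print(calculate_shift_sketch(4096))  # 1.15
```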
@@ -851,17 +938,16 @@ def __call__(
             latents,
         )

-        # 5. Prepare mask and masked image latents
+        # 6. Prepare mask and masked image latents
         if masked_image_latents is not None:
             masked_image_latents = masked_image_latents.to(latents.device)
         else:
-            image = self.image_processor.preprocess(image, height=height, width=width)
             mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)

-            masked_image = image * (1 - mask_image)
+            masked_image = init_image * (1 - mask_image)
             masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)

-            height, width = image.shape[-2:]
+            height, width = init_image.shape[-2:]
             mask, masked_image_latents = self.prepare_mask_latents(
                 mask_image,
                 masked_image,
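
The masked-image construction is unchanged apart from reusing `init_image`: the binarized mask is 1 where content should be regenerated, so multiplying by `(1 - mask_image)` blanks out exactly the fill region. A toy illustration:

```python
import torch

init_image = torch.ones(1, 3, 4, 4)           # stand-in for a preprocessed image
mask_image = torch.zeros(1, 1, 4, 4)
mask_image[..., 1:3, 1:3] = 1.0               # 1 = region to inpaint

masked_image = init_image * (1 - mask_image)  # fill region zeroed, the rest kept
print(masked_image[0, 0])
```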
@@ -876,23 +962,6 @@ def __call__(
             )
             masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)

-        # 6. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
-        image_seq_len = latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.get("base_image_seq_len", 256),
-            self.scheduler.config.get("max_image_seq_len", 4096),
-            self.scheduler.config.get("base_shift", 0.5),
-            self.scheduler.config.get("max_shift", 1.15),
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            sigmas=sigmas,
-            mu=mu,
-        )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)