
Flux Control (Depth/Canny) + Inpaint #10192


Merged: 14 commits, Dec 18, 2024

Conversation

@affromero (Contributor):

What does this PR do?

This PR adds a Flux pipeline for image inpainting using Flux-dev Depth/Canny.

import torch
from diffusers import FluxControlInpaintPipeline
from diffusers.models.transformers import FluxTransformer2DModel
from transformers import T5EncoderModel
from diffusers.utils import load_image, make_image_grid
from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux
from PIL import Image
import numpy as np

pipe = FluxControlInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Depth-dev",
    torch_dtype=torch.bfloat16,
)
# If you are GPU-memory constrained, swap in NF4-quantized weights and rely on
# model CPU offload; do NOT also call pipe.to("cuda") in that case.
# ---------------------------------------------------------------
transformer = FluxTransformer2DModel.from_pretrained(
    "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
)
text_encoder_2 = T5EncoderModel.from_pretrained(
    "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.enable_model_cpu_offload()
# ---------------------------------------------------------------
# Without the memory-saving block above, move the full pipeline to the GPU instead:
# pipe.to("cuda")

prompt = "a blue robot singing opera with human-like expressions"
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

# White rectangle over the robot's head; only this region is inpainted
head_mask = np.zeros_like(image)
head_mask[65:580, 300:642] = 255
mask_image = Image.fromarray(head_mask)

processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(image)[0].convert("RGB")

output = pipe(
    prompt=prompt,
    image=image,
    control_image=control_image,
    mask_image=mask_image,
    num_inference_steps=30,
    strength=0.9,
    guidance_scale=50.0,
    generator=torch.Generator().manual_seed(42),
).images[0]
make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save("output.png")
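As a side note on the parameters above: in diffusers-style inpainting pipelines, `strength` typically trims the denoising schedule rather than adding steps, so with `strength=0.9` and `num_inference_steps=30` only about 27 denoising steps actually run. A minimal sketch of that bookkeeping (an assumption about the scheduler convention, not code from this PR):

```python
def effective_steps(num_inference_steps: int, strength: float) -> int:
    # Denoising starts partway through the schedule; strength=1.0 runs all steps.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start

print(effective_steps(30, 0.9))  # 27 with the example values above
```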

[Image: grid of input, depth control image, mask, and inpainted output]


Who can review?

@sayakpaul @yiyixuxu

Following this interaction about contributing this pipeline to main.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member) left a comment:

Thanks for starting this! I have left some comments, LMK if they make sense.

else:
    masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)

masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

We could allow passing precomputed masked_image_latents but this is not a blocker for this PR.
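If precomputed masked_image_latents were accepted, callers would need to apply the same affine normalization the pipeline applies after VAE encoding. A standalone numpy sketch of that map and its inverse (the shift/scale values here are illustrative defaults, not necessarily the FLUX VAE's actual config):

```python
import numpy as np

def normalize_latents(latents, shift_factor=0.1159, scaling_factor=0.3611):
    # Same (latents - shift) * scale step the pipeline applies after encoding.
    return (latents - shift_factor) * scaling_factor

def denormalize_latents(latents, shift_factor=0.1159, scaling_factor=0.3611):
    # Inverse map, as used before VAE decoding.
    return latents / scaling_factor + shift_factor

x = np.random.default_rng(0).standard_normal((1, 16, 4, 4))
assert np.allclose(denormalize_latents(normalize_latents(x)), x)
```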

Comment on lines 929 to 951
if masked_image_latents is not None:
    masked_image_latents = masked_image_latents.to(latents.device)
else:
    image = self.image_processor.preprocess(image, height=height, width=width)
    mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)

    masked_image = image * (1 - mask_image)
    masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)

    height, width = image.shape[-2:]
    mask, masked_image_latents = self.prepare_mask_latents(
        mask_image,
        masked_image,
        batch_size,
        num_channels_latents,
        num_images_per_prompt,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
    )
    masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)

All of this could go to prepare_mask_latents.
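For context, the step being folded into prepare_mask_latents starts from a simple pixelwise blend: the region to be inpainted is zeroed out before VAE encoding. A standalone numpy sketch of that blend (shapes and names are illustrative, not the pipeline's actual tensors):

```python
import numpy as np

def make_masked_image(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    # mask is 1.0 where inpainting should happen, 0.0 where the image is kept.
    return image * (1 - mask)

image = np.ones((1, 3, 4, 4), dtype=np.float32)
mask = np.zeros((1, 1, 4, 4), dtype=np.float32)
mask[..., :2, :] = 1.0  # top half will be regenerated
masked = make_masked_image(image, mask)
assert masked[..., :2, :].sum() == 0       # masked region zeroed
assert (masked[..., 2:, :] == 1).all()     # unmasked region untouched
```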

@affromero (Contributor, Author):

I believe I've addressed all your comments, @sayakpaul @hlky. Please let me know if there's anything I might have missed.

@hlky (Contributor) left a comment:

Left a suggestion and a couple comments but good to go on my end. Thanks!

Comment on lines +123 to +137
# def test_flux_different_prompts(self):
#     pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)

#     inputs = self.get_dummy_inputs(torch_device)
#     output_same_prompt = pipe(**inputs).images[0]

#     inputs = self.get_dummy_inputs(torch_device)
#     inputs["prompt_2"] = "a different prompt"
#     output_different_prompts = pipe(**inputs).images[0]

#     max_diff = np.abs(output_same_prompt - output_different_prompts).max()

#     # Outputs should be different here
#     # For some reasons, they don't show large differences
#     assert max_diff > 1e-6

#10025 (comment)

Possibly related, here's the effect it had on Redux #10056 (comment)

@asomoza (Member) commented Dec 13, 2024:

It works OK. I still have to see if it works for some more real use cases, but it produces the expected results.

Prompt: "a yellow martian"

[Image: inpainted output]

Funny that it generated this "unknown logo" with this example ^^

[Image: close-up of the "unknown logo"]

@hlky (Contributor) commented Dec 13, 2024:

@affromero here are a couple of example PRs for documentation: https://github.com/huggingface/diffusers/pull/10021/files and https://github.com/huggingface/diffusers/pull/10131/files (see the files under docs/). Ping @stevhliu if you need help with the docs or when they're ready for review.

@affromero (Contributor, Author):

I think this should be good, but not 100% sure @stevhliu, let me know what I can tweak or do next.

@sayakpaul (Member):

Will let @hlky take care of the final merging.

@hlky (Contributor) commented Dec 18, 2024:

Failing test is unrelated, good to go, thanks @affromero!

@hlky hlky merged commit 83709d5 into huggingface:main Dec 18, 2024
11 of 12 checks passed
@vladmandic (Contributor):

Please update src/diffusers/pipelines/auto_pipeline.py as well.

@yiyixuxu @sayakpaul @asomoza can we somehow make this part of the mandatory template? 90% of the time when an img2img or inpaint pipeline is added, the AutoPipeline mapping is NOT updated.
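For reference, the AutoPipeline mappings in auto_pipeline.py are plain OrderedDicts keyed by model family, so registering a new inpaint pipeline is a one-line addition. A hedged sketch of the shape of that change (stand-in strings instead of the real pipeline classes to keep this runnable, and the key name for the control variant is an assumption, not taken from the actual file):

```python
from collections import OrderedDict

# Sketch of the mapping shape; the real file maps key -> pipeline class.
AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict([
    ("flux", "FluxInpaintPipeline"),
    ("flux-control", "FluxControlInpaintPipeline"),  # hypothetical new entry
])

assert "flux-control" in AUTO_INPAINT_PIPELINES_MAPPING
```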

Foundsheep pushed a commit to Foundsheep/diffusers that referenced this pull request Dec 23, 2024
* flux_control_inpaint - failing test_flux_different_prompts

* removing test_flux_different_prompts?

* fix style

* fix from PR comments

* fix style

* reducing guidance_scale in demo

* Update src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py

Co-authored-by: hlky <[email protected]>

* make

* prepare_latents is not copied from

* update docs

* typos

---------

Co-authored-by: affromero <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: hlky <[email protected]>
sayakpaul added a commit that referenced this pull request Dec 23, 2024

7 participants