
Attention mask for Flux & SD3 #10044

Open
rootonchair wants to merge 7 commits into main

Conversation

rootonchair
Contributor

What does this PR do?

Fixes #10025
Fixes #8673

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@rootonchair
Contributor Author

@sayakpaul @yiyixuxu, how should I test this feature? modify the original flux pipeline?

@bghira
Contributor

bghira commented Nov 30, 2024

refer to this flux transformer implementation for attn masking details

query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

# Pass the (optional) mask straight into SDPA; None falls back to full attention.
hidden_states = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
Collaborator

well I think what we want is not a specific implementation that applies attention_mask for Flux; it's just to allow it to be passed down all the way from the pipeline, to the transformer, and then to the attention processor, so users can experiment with a custom attention mask

cc @christopher5106, is what I described here something you have in mind?
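
A minimal sketch of that pass-down idea, with stand-in function names rather than the actual diffusers signatures (only joint_attention_kwargs mirrors the real pipeline argument):

import torch
import torch.nn.functional as F

def attention_processor(query, key, value, attention_mask=None):
    # The processor does not build a mask itself; it only forwards whatever it is given.
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)

def transformer_forward(hidden_states, joint_attention_kwargs=None):
    kwargs = joint_attention_kwargs or {}
    q = k = v = hidden_states
    return attention_processor(q, k, v, attention_mask=kwargs.get("attention_mask"))

def pipeline_call(latents, joint_attention_kwargs=None):
    # The pipeline threads the user-supplied kwargs through unchanged.
    return transformer_forward(latents, joint_attention_kwargs=joint_attention_kwargs)

x = torch.randn(1, 1, 6, 8)
mask = torch.ones(1, 1, 1, 6, dtype=torch.bool)
print(pipeline_call(x, joint_attention_kwargs={"attention_mask": mask}).shape)  # torch.Size([1, 1, 6, 8])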

Collaborator

and if you do have a specific implementation that we want to add to diffusers, maybe you can run some experiments to help us decide whether it's meaningful

Contributor

it's the encoder attention mask though, just following the convention of PixArt and other DiTs that rely on attention masking. masking the attention arbitrarily doesn't unlock new use cases, does it? if so, providing examples of those would be nice.

Collaborator

@rootonchair yes, feel free to modify the pipeline/model to test, and provide the experiment results to us :)

how should I test this feature? modify the original flux pipeline?

Collaborator

cc @christopher5106 - would you be able to provide a use case, since it was the original ask?

masking the attention arbitrarily doesn't unlock new use cases, does it? if so, providing examples of those would be nice.

Contributor Author

@rootonchair yes, feel free to modify the pipeline/model to test, and provide the experiment results to us :)

how should I test this feature? modify the original flux pipeline?

Sure, perhaps the simplest test would be passing a padded prompt

Contributor

Checking the softmax scores for padded positions would be another good test.
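
For what it's worth, a tiny self-contained sketch of that check (shapes and token counts are invented): with a boolean mask, the softmax weight on padded key positions should come out exactly zero.

import torch

q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 5, 8)
# Last two key positions are padding.
key_mask = torch.tensor([True, True, True, False, False]).view(1, 1, 1, 5)

scores = (q @ k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)
scores = scores.masked_fill(~key_mask, float("-inf"))
probs = scores.softmax(dim=-1)

# Softmax mass on the padded columns must be zero for every query.
assert torch.all(probs[..., 3:] == 0)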

@yiyixuxu added the roadmap (Add to current release roadmap) label on Dec 3, 2024
@rootonchair
Contributor Author

rootonchair commented Dec 4, 2024

refer to this flux transformer implementation for attn masking details

This implementation is a little confusing to me. If the encoder attention mask is [1,1,1,0,0],
should the attention mask of the joint attention process be the following (assuming the image sequence length is 0)?

[
[1,1,1,0,0],
[1,1,1,0,0],
[1,1,1,0,0],
[0,0,0,0,0],
[0,0,0,0,0]
]

Simply broadcasting the mask would result in

[
[1,1,1,0,0],
[1,1,1,0,0],
[1,1,1,0,0],
[1,1,1,0,0],
[1,1,1,0,0]
]

while the last two tokens are just padding tokens:

[
[1,1,1,0,0],
[1,1,1,0,0],
[1,1,1,0,0],
[1,1,1,0,0], <--- query's pad token
[1,1,1,0,0] <--- query's pad token
]
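
For illustration, here is a small, self-contained sketch (token counts invented) of building the joint mask from the tokenizer's 1D padding mask. Broadcasting over the query dimension hides padded keys from every query; padded query rows are left unmasked, which is the usual convention since padded keys stay masked in every layer anyway.

import torch
import torch.nn.functional as F

# 1D padding mask from the tokenizer: 1 = real token, 0 = padding.
text_mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.bool)   # (batch, text_len)
image_len = 4                                                    # image tokens are never padded
image_mask = torch.ones(text_mask.shape[0], image_len, dtype=torch.bool)

# Joint sequence = text tokens followed by image tokens (order is illustrative).
joint_mask = torch.cat([text_mask, image_mask], dim=1)           # (batch, seq_len)

# Broadcast over heads and queries: padded keys are hidden from every query.
attn_mask = joint_mask[:, None, None, :]                         # (batch, 1, 1, seq_len)

q = k = v = torch.randn(text_mask.shape[0], 1, joint_mask.shape[1], 8)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 1, 9, 8])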

@bghira
Contributor

bghira commented Dec 4, 2024

no, Flux doesn't mask padding tokens. doing so harms prompt adherence

@rootonchair marked this pull request as ready for review on December 23, 2024
@rootonchair
Contributor Author

rootonchair commented Dec 23, 2024

This is my result testing with the Flux pipeline:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.to(torch.float16)

prompt = [
    "a tiny astronaut hatching from an egg on the moon", 
    "A cat holding a sign that says hello world"
]
# Build the T5 padding mask (1 = real token, 0 = padding) for both prompts.
attention_mask = pipe.tokenizer_2(
    prompt,
    padding="max_length",
    max_length=512,
    truncation=True,
    return_length=False,
    return_overflowing_tokens=False,
    return_tensors="pt",
).attention_mask
attention_mask = attention_mask.to(device="cuda", dtype=torch.float16)
# Pass the custom mask down to the attention processors via joint_attention_kwargs.
out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    height=768,
    width=1360,
    num_inference_steps=50,
    joint_attention_kwargs={'attention_mask': attention_mask},
    generator=torch.Generator(device="cuda").manual_seed(42)
).images
out[0].save("image.png")
out[1].save("image1.png")

Without mask
old_image
old_image1

With mask
new_image
new_image1

@bghira
Contributor

bghira commented Dec 23, 2024

patch embed artifacts in the masked one.

@rootonchair
Contributor Author

For SD3, it doesn't have the same effect. I will try to figure out why and come up with a test script soon.

@rootonchair
Contributor Author

For SD3, I used the script below:

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

prompt = [
    "",  # negative prompt for first image
    "",  # negative prompt for second image
    "smiling cartoon dog sits at a table, coffee mug on hand, as a room goes up in flames. “This is fine,” the dog assures himself.", 
    "A cat holding a sign that says hello world",
]
# Build the T5 padding mask for the negative and positive prompts together,
# since SD3 runs classifier-free guidance with both in one batch.
t5_attention_mask = pipe.tokenizer_3(
    prompt,
    padding="max_length",
    max_length=256,
    truncation=True,
    add_special_tokens=True,
    return_tensors="pt",
).attention_mask

attention_mask = t5_attention_mask
attention_mask = attention_mask.to(device="cuda", dtype=torch.float16)

# Keep only the positive prompts for the pipeline call; the mask still covers all four entries.
prompt = prompt[2:]
print(prompt)
image = pipe(
    prompt,
    joint_attention_kwargs={'attention_mask': attention_mask},
    generator=torch.Generator(device="cuda").manual_seed(42),
).images
image[0].save("new_sd_image.png")
image[1].save("new_sd_image1.png")

Without mask
old_sd_image

old_sd_image1

With mask
new_sd_image

new_sd_image1

@rootonchair requested a review from yiyixuxu on December 24, 2024
@christopher5106

it looks like there is a bug

@bghira
Contributor

bghira commented Dec 24, 2024

yeah, I'm not too sure what's happening here anymore. is the attention_mask ignoring image token inputs? is it just a text encoder attn mask now?

I thought the idea was to add encoder_attention_mask and image_attention_mask parameters and make them cooperate with the joint_attention_kwargs attention_mask, which would supersede those two.

the test should not change the outputs at all, and yours does, which indicates there is something wrong.
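
A hedged sketch of that precedence (encoder_attention_mask / image_attention_mask are hypothetical parameters here; only joint_attention_kwargs matches an existing argument):

import torch

def resolve_joint_mask(encoder_attention_mask, image_attention_mask, joint_attention_kwargs=None):
    explicit = (joint_attention_kwargs or {}).get("attention_mask")
    if explicit is not None:
        # A full joint mask supplied by the user supersedes the per-modality masks.
        return explicit
    # Otherwise, concatenate the text and image key masks along the sequence dimension.
    return torch.cat([encoder_attention_mask, image_attention_mask], dim=-1)

text_mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.bool)
image_mask = torch.ones(1, 4, dtype=torch.bool)
print(resolve_joint_mask(text_mask, image_mask).shape)  # torch.Size([1, 9])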

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot added the stale (Issues that haven't received updates) label on Jan 18, 2025
@IorHon

IorHon commented Apr 10, 2025

Has anyone already gotten it working properly? I'm willing to pay anyone who can integrate it into ComfyUI for me.

Successfully merging this pull request may close these issues:

Attention masks are missing in SD3 to mask out text padding tokens
attention mask for transformer Flux