Attention mask for Flux & SD3 #10044
Conversation
@sayakpaul @yiyixuxu, how should I test this feature? Modify the original Flux pipeline?
Refer to this Flux transformer implementation for attention masking details.
```python
query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
if attention_mask is not None:
```
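For context, a minimal sketch of what the masking branch quoted above could do (an illustration under assumptions, not this PR's actual diff): treat `attention_mask` as a 0/1 mask over the full joint text+image sequence, turn it into an additive bias, and hand it to `scaled_dot_product_attention`.

```python
import torch
import torch.nn.functional as F


def joint_attention_with_mask(query, key, value, attention_mask=None):
    """Sketch: apply an optional 0/1 mask over the joint sequence via SDPA's attn_mask.

    query/key/value: (batch, heads, seq_len, head_dim)
    attention_mask: (batch, seq_len), 1 = keep, 0 = mask out
    """
    if attention_mask is not None:
        attn_bias = attention_mask[:, None, None, :].to(query.dtype)   # broadcast over heads and queries
        attn_bias = (1.0 - attn_bias) * torch.finfo(query.dtype).min   # 1 -> 0.0 bias, 0 -> very negative
    else:
        attn_bias = None
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_bias, dropout_p=0.0, is_causal=False
    )
```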
Well, I think what we want is not a specific implementation that applies attention_mask for Flux; it's just to allow it to pass down all the way from the pipeline, to the transformer, and then to the attention processor, so users can experiment with a custom attention mask.
cc @christopher5106: is what I described here something you have in mind?
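To make that concrete, here is a rough sketch (not this PR's implementation) of how a user could experiment with a custom mask once it is threaded through: wrap the default processor and inject the mask. The class and method names (`FluxAttnProcessor2_0`, `set_attn_processor`) follow current diffusers but may differ across versions, and the sketch assumes the processor actually applies `attention_mask` once this PR lands.

```python
import torch
from diffusers import FluxPipeline
from diffusers.models.attention_processor import FluxAttnProcessor2_0


class CustomMaskProcessor(FluxAttnProcessor2_0):
    """Sketch: reuse the default Flux attention processor but force a
    user-chosen mask into every attention call."""

    def __init__(self, custom_mask):
        super().__init__()
        self.custom_mask = custom_mask

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, image_rotary_emb=None):
        return super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=self.custom_mask,  # only has an effect once the processor uses the mask
            image_rotary_emb=image_rotary_emb,
        )


pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
custom_mask = torch.ones(1, 512)  # placeholder; shape and semantics depend on the final design
pipe.transformer.set_attn_processor(CustomMaskProcessor(custom_mask))
```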
And if you do have a specific implementation that we want to add to diffusers, maybe you can run some experiments to help us decide whether it's meaningful.
It's the encoder attention mask, though, just following the convention of PixArt and other DiTs that rely on attention masking. Masking the attention arbitrarily doesn't unlock new use cases, does it? If so, providing examples of those would be nice.
> how should I test this feature? Modify the original Flux pipeline?

@rootonchair yes, feel free to modify the pipeline/model to test, and provide the experiment results to us :)
> Masking the attention arbitrarily doesn't unlock new use cases, does it? If so, providing examples of those would be nice.

cc @christopher5106: would you be able to provide a use case, since it was the original ask?
> @rootonchair yes, feel free to modify the pipeline/model to test, and provide the experiment results to us :)
>
> how should I test this feature? Modify the original Flux pipeline?

Sure, perhaps the simplest test would be passing a padded prompt.
Checking the softmax scores for padded positions: this implementation is a little confusing to me. If the encoder attention mask is

then simply broadcasting the mask would result in

while the last two tokens are just padding tokens.
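To make the broadcasting concern concrete, here is a small standalone sketch (no model involved, made-up token counts, and the text-before-image ordering is an assumption): a text-only padding mask has to be extended with ones for the image tokens before it can mask the joint sequence, and the softmax check then shows the padded text keys receiving essentially zero weight.

```python
import torch

torch.manual_seed(0)
batch, n_img, n_txt = 1, 4, 6                      # made-up token counts
text_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])     # last two text tokens are padding

# The joint sequence is [text tokens, image tokens] here (ordering is an assumption);
# image tokens should never be masked, so extend the text mask with ones.
joint_mask = torch.cat([text_mask, torch.ones(batch, n_img, dtype=text_mask.dtype)], dim=1)

scores = torch.randn(batch, n_txt + n_img, n_txt + n_img)            # toy attention logits
bias = (1.0 - joint_mask.float()) * torch.finfo(torch.float32).min   # 1 -> 0 bias, 0 -> very negative
probs = torch.softmax(scores + bias[:, None, :], dim=-1)

print(probs[0, :, 4:6].max())  # ~0: padded text keys get no attention weight
print(probs[0].sum(dim=-1))    # every query row still sums to 1 over the unmasked keys
```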
No, Flux doesn't mask padding tokens. Doing so harms prompt adherence.
This is my result testing with the Flux pipeline:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.to(torch.float16)

prompt = [
    "a tiny astronaut hatching from an egg on the moon",
    "A cat holding a sign that says hello world",
]

attention_mask = pipe.tokenizer_2(
    prompt,
    padding="max_length",
    max_length=512,
    truncation=True,
    return_length=False,
    return_overflowing_tokens=False,
    return_tensors="pt",
).attention_mask
attention_mask = attention_mask.to(device="cuda", dtype=torch.float16)

out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    height=768,
    width=1360,
    num_inference_steps=50,
    joint_attention_kwargs={"attention_mask": attention_mask},
    generator=torch.Generator(device="cuda").manual_seed(42),
).images
out[0].save("image.png")
out[1].save("image1.png")
```
Patch embed artifacts in the masked one.
For SD3, it doesn't have the same effect. I will try to figure out why and come up with a test script soon.
For SD3 I use the script below:

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

prompt = [
    "",  # negative prompt for first image
    "",  # negative prompt for second image
    "smiling cartoon dog sits at a table, coffee mug on hand, as a room goes up in flames. “This is fine,” the dog assures himself.",
    "A cat holding a sign that says hello world",
]

t5_attention_mask = pipe.tokenizer_3(
    prompt,
    padding="max_length",
    max_length=256,
    truncation=True,
    add_special_tokens=True,
    return_tensors="pt",
).attention_mask
attention_mask = t5_attention_mask
attention_mask = attention_mask.to(device="cuda", dtype=torch.float16)

prompt = prompt[2:]
print(prompt)

image = pipe(
    prompt,
    joint_attention_kwargs={"attention_mask": attention_mask},
    generator=torch.Generator(device="cuda").manual_seed(42),
).images
image[0].save("new_sd_image.png")
image[1].save("new_sd_image1.png")
```
It looks like there is a bug.
Yeah, I'm not too sure what's happening here anymore. Is the attention_mask ignoring image token inputs? Is it just a text encoder attention mask now? I thought the idea was to add encoder_attention_mask and image_attention_mask parameters and make them cooperate with the joint_attention_kwargs attention_mask, which would supersede those two. The test should not change the outputs at all, and yours does, which indicates there is something wrong.
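A sanity check along those lines could look like the sketch below (hypothetical, assuming the PR routes `joint_attention_kwargs={"attention_mask": ...}` down to the processors and interprets an all-ones mask over the 512-token T5 sequence as "keep everything"): generate with and without the mask under the same seed and compare the images.

```python
import numpy as np
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

common = dict(
    prompt="A cat holding a sign that says hello world",
    height=512,
    width=512,
    num_inference_steps=20,
)

# Reference run: no mask at all.
ref = pipe(**common, generator=torch.Generator("cuda").manual_seed(0)).images[0]

# All-ones mask, i.e. nothing is actually masked (512 = the pipeline's default T5 max length).
all_ones = torch.ones(1, 512, device="cuda", dtype=torch.bfloat16)
masked = pipe(
    **common,
    joint_attention_kwargs={"attention_mask": all_ones},
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]

# With nothing masked, the outputs should match (up to attention-backend numerics).
diff = np.abs(np.asarray(ref, dtype=np.int16) - np.asarray(masked, dtype=np.int16))
print(diff.max())
```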
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Has anyone already gotten it working properly? I'm willing to pay for the work to anyone who can integrate it into ComfyUI for me.
What does this PR do?
Fixes #10025
Fixes #8673
Before submitting
Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.