
attention mask for transformer Flux #10025


Closed
christopher5106 opened this issue Nov 26, 2024 · 19 comments · Fixed by #10122 · May be fixed by #10044 or #10053
Labels
bug Something isn't working

Comments

@christopher5106

christopher5106 commented Nov 26, 2024

Describe the bug

Is it possible to get back the attention_mask argument in the Flux attention processor

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask)

https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L1910

in order to tweak things a bit? Otherwise, the attention_mask argument is unused.

Thanks a lot

Reproduction

pip install diffusers

Logs

No response

System Info

Ubuntu

Who can help?

@yiyixuxu @sayakpaul @DN6 @asomoza

@christopher5106 christopher5106 added the bug Something isn't working label Nov 26, 2024
@christopher5106 christopher5106 changed the title from attention mask for transformer to attention mask for transformer Flux Nov 26, 2024
@sayakpaul
Member

Cc: @yiyixuxu

@yiyixuxu
Collaborator

hey @christopher5106
yes! do you want to open a PR?

@rootonchair
Contributor

Hi @yiyixuxu, I am working on this issue and it seems like attention_mask is not being used by any of the pipelines. Could you help me find a case where an attention mask is actually used and passed to the attention processor?

@sayakpaul
Member

sayakpaul commented Nov 28, 2024

Thanks, @rootonchair! That's because the original Flux implementation doesn't really use any mask, so the Flux-related pipelines don't use one either. So, if we were to actually use the attention mask in the Flux attention processor, users would have to make sure to pass it accordingly in their implementations.

@bghira
Contributor

bghira commented Nov 30, 2024

it was never done, so it can't be "brought back". you will have to figure out how to handle the image inputs as well. but this probably should not be done, since the original model wasn't trained with it and not even large-scale finetuning on 32x H100 has been able to reliably introduce attention masking.

finetuning from e.g. simpletuner or kohya scripts offers the option to reintroduce the attn masking, but no inference engine (comfyUI, Forge, A1111, Diffusers) supports it, and no one has really complained about that yet.

the few finetunes that lightly introduced attn masking did not go long enough to have much of an impact from it.

@christopher5106
Author

i implemented a door so that you can do whatever you want with the extra args. it makes no sense not to accept that attention_mask argument and link it to the attention mechanism

@bonlime
Contributor

bonlime commented Dec 3, 2024

I can give 2 motivations for introducing attention masking:

  1. You have added Flux Redux, which in the default implementation only works as an image-variation method; if you try to combine it with text prompts it completely ignores the text. This can easily be fixed by passing attn masks that put less weight on the image condition tokens (see the note after this list). Here is an example: reference image on the right, scales written on the bottom, prompt "anime style drawing, dog jumping in the field, castle background". You can see that scales of ~0.06-0.08 achieve decent prompt following, but this requires passing attn_mask to work.
    (example images comparing results at different scales)

  2. There is a recent paper, OminiControl, that implements reference-based image generation, which requires passing a scale through the attn mask. Even if you wouldn't add it to diffusers, there are people out there (me, for example) who would like to use native diffusers components instead of copy-pasted ones.
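
For intuition (a restatement of the trick in item 1, not part of the original comment): for a single query with pre-softmax logits $z_k$, adding a bias of $\log s$ to the logits of the reference-token positions $R$ multiplies their unnormalized weight by $s$ before renormalization,

$$
a_j = \frac{s\, e^{z_j}}{\sum_{k \notin R} e^{z_k} + s \sum_{k \in R} e^{z_k}}, \qquad j \in R,
$$

so a scale $s < 1$ (e.g. the ~0.06-0.08 range mentioned above) smoothly down-weights the image-condition tokens relative to the text tokens.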

@bghira
Contributor

bghira commented Dec 3, 2024

what do you mean, scale the attention mask? it is binary, false/true, and the masked positions get a -inf softmax score. pytorch requires the values to be 0 or 1. there is no float mask support, for performance reasons.

i was hoping OP would share their code and implementation details because it isn't quite lining up for me. i would still rather have encoder attn masking implemented fully and be more than just "a door".

@bonlime
Contributor

bonlime commented Dec 3, 2024

it is binary, false/true and the masked positions have a -inf softmax score. pytorch requires the values be 0 or 1. there is no float mask support for performance reasons.

torch does support a float mask, which is treated as an attention bias and added before the softmax. In my implementation above, the bias is equal to torch.log(scale.clamp_min_(1e-5)), for example, and is only added to some parts of the attention.
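
A minimal self-contained check of that behavior (shapes and values here are invented for illustration; this assumes PyTorch 2.x, where a floating-point attn_mask is added to the attention logits before the softmax):

import torch
import torch.nn.functional as F

# toy shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 6, 8)
v = torch.randn(1, 1, 6, 8)

# float mask acts as an additive bias: down-weight the last two keys by 0.05
scale = 0.05
bias = torch.zeros(4, 6)
bias[:, 4:] = torch.log(torch.tensor(scale))

out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

# manual reference: softmax(q k^T / sqrt(d) + bias) v
logits = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5 + bias
out_manual = torch.softmax(logits, dim=-1) @ v

print(torch.allclose(out_sdpa, out_manual, atol=1e-6))  # True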

the problem is that you're tying "enable passing the argument" to "add implementations for where it's used". why not just fix the issue of "there is an argument that is not being forwarded correctly", without arguing about what the "correct" way to do attn masking is?

@bghira
Contributor

bghira commented Dec 3, 2024

if you wanted to just fix it you could copy-paste it, but the discussion here is about what's best for the Diffusers project and all of its users, not just a select few

would also be nice to see some performance numbers for the use of sparse attention masks, which are notoriously slow.

@yiyixuxu
Collaborator

yiyixuxu commented Dec 4, 2024

@bonlime would you be able to share your code? It looks super cool; we are working on supporting prompts for Flux Redux here too: #10056

I think it is very easy to support passing a custom attention mask down, so we can support that first if there is a meaningful use case; cc @hlky here, let me know what you think too

as for a default implementation, happy to explore and support that too, but we need to see tests/experiments, so it's a bigger decision to make. I think it's being worked on in #10044

@hlky
Contributor

hlky commented Dec 4, 2024

Yes, I think passing a custom attention mask through the pipeline is the easiest thing to support initially; sounds like that would work for both @christopher5106 and @bonlime

@bonlime
Contributor

bonlime commented Dec 4, 2024

@yiyixuxu here is the code

# assuming we already got reference_embeds somehow
prompt_embeds = torch.cat([prompt_embeds, *reference_embeds], dim=1)
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=torch.bfloat16)
# hardcode for now since it doesn't really change. may need adjustment if you want to support
# multiple images as reference
cond_size = 729
full_attention_size = max_sequence_length + latents.size(1) + cond_size
attention_mask = torch.zeros(
    (full_attention_size, full_attention_size), device=latents.device, dtype=latents.dtype
)
reference_scale: float = 0.05 # example
bias = torch.log(
    torch.tensor(reference_scale, dtype=latents.dtype, device=latents.device).clamp(min=1e-5, max=1)
)
attention_mask[:, max_sequence_length : max_sequence_length + cond_size] = bias
# you need patched attn_processor to use this argument
joint_attention_kwargs=dict(attention_mask=attention_mask),
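
(For illustration only, a stripped-down, runnable toy version of the same mask construction; all sizes below are made up, and the column range follows the snippet above, where the reference embeds sit right after the text tokens:)

import torch

# dummy sizes, purely illustrative
max_sequence_length = 512    # text tokens
cond_size = 729              # reference/condition tokens (hardcoded above)
num_latent_tokens = 1024     # image latent tokens

full_attention_size = max_sequence_length + cond_size + num_latent_tokens
attention_mask = torch.zeros(full_attention_size, full_attention_size)

reference_scale = 0.05
bias = torch.log(torch.tensor(reference_scale).clamp(min=1e-5, max=1))

# every query pays a log(0.05) penalty when attending to the reference-token
# columns, i.e. their unnormalized attention weight is scaled by 0.05
attention_mask[:, max_sequence_length : max_sequence_length + cond_size] = bias

print(attention_mask.shape)  # torch.Size([2265, 2265])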

@bghira
Contributor

bghira commented Dec 4, 2024

@yiyixuxu we have a fully functional attn mechanism implemented here, which people have used to dedistill models; it was ported into kohya's trainer scripts and has been adapted for inference elsewhere. thanks to @AmericanPresidentJimmyCarter for writing this code. it has been tested, experimented with, and used for months now. but we haven't upstreamed this code because of the issues mentioned in the discussion from last month.

@AmericanPresidentJimmyCarter
Contributor

Ensuring that your implementation is backwards compatible with the one for LibreFLUX/SimpleTuner would be good. There have been a number of finetunes since it was released.

@yiyixuxu
Collaborator

yiyixuxu commented Dec 4, 2024

oh thanks!
we can support both, I think @bonlime 's use case is valid too

@rootonchair is adding the implementation you're referring to in #10044, no?

@yiyixuxu
Collaborator

yiyixuxu commented Dec 4, 2024

@bonlime
the only change you need is passing attention_mask here, do I understand correctly?

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
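
(i.e. the same call but with the mask forwarded, as quoted in the issue description:)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask)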

@AmericanPresidentJimmyCarter
Contributor

Both scale and attn_mask are natively supported by scaled_dot_product_attention, so yeah, you should just need a simple way to pass them in.

@bonlime
Contributor

bonlime commented Dec 4, 2024

@yiyixuxu correct, it's a one-line change to make this work
