attention mask for transformer Flux #10025
Comments
Cc: @yiyixuxu
hey @christopher5106
Hi @yiyixuxu, I am working on this issue and it seems like the attention mask is never actually passed down by the Flux pipelines.
Thanks, @rootonchair! The reason that is not the case is that the original Flux implementation doesn't really use any mask, so the Flux-related pipelines don't use one either. If we were to actually use the attention mask in the Flux attention processor, users would have to make sure to pass it accordingly in their implementations.
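To make that concrete, here is a rough sketch of what "passing it accordingly" could look like for the common case of ignoring padded T5 text tokens. The sequence layout (text tokens followed by image tokens) matches how Flux concatenates the two streams, but the shapes and lengths below are illustrative assumptions, not the current pipeline API:

import torch

batch_size, text_len, image_len = 1, 512, 4096  # illustrative sizes only
# True for real text tokens, False for padding (e.g. from the tokenizer's attention mask)
text_keep = torch.zeros(batch_size, text_len, dtype=torch.bool)
text_keep[:, :77] = True  # assume only the first 77 text tokens are real
image_keep = torch.ones(batch_size, image_len, dtype=torch.bool)  # never mask image tokens
keep = torch.cat([text_keep, image_keep], dim=1)  # (batch, text_len + image_len)
# Broadcastable key mask for F.scaled_dot_product_attention: shape (batch, 1, 1, S),
# True = attend, so no query can attend to a padded text token.
attention_mask = keep[:, None, None, :]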
It was never done, so it can't be "brought back", and you would also have to figure out how to handle the image inputs. But this probably should not be done: the original model wasn't trained with it, and not even large-scale finetuning on 32x H100 has been able to reliably introduce attention masking. Finetuning via e.g. SimpleTuner or kohya scripts offers the option to reintroduce attention masking, but no inference engine (ComfyUI, Forge, A1111, Diffusers) supports it, and no one has really complained about that yet. The few finetunes that lightly introduced attention masking did not run long enough for it to have much of an impact.
I implemented a door; then you can do whatever you want with the extra args. It makes no sense not to allow that attention_mask argument to be linked to the attention mechanism.
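For concreteness, here is a simplified sketch of what linking the attention_mask argument to the attention computation means. This is not the upstream FluxAttnProcessor2_0 and it omits the Flux-specific details (RoPE, QK norms, the joint text/image projections); it only shows the general shape of such a processor and the single place the mask has to go:

import torch
import torch.nn.functional as F

class MaskedAttnProcessorSketch:
    # Simplified processor: accepts attention_mask and actually hands it to SDPA.
    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, **kwargs):
        batch_size = hidden_states.shape[0]
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = query.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # The one meaningful change: forward the mask (bool or float bias) to SDPA
        # instead of silently dropping it.
        out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        out = out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        return attn.to_out[0](out)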
I can give two motivations for introducing attention masking:
What do you mean by scaling the attention mask? It is binary, true/false, and the masked positions effectively get -inf before the softmax. I was hoping the OP would share their code and implementation details, because it isn't quite lining up for me. I would still rather see encoder attention masking implemented fully and be more than just "a door".
torch does support a float mask, which is treated as an attention bias and added before the softmax. In my implementation above the bias is equal to log(reference_scale). The problem is that you're trying to enable passing the argument and also add an implementation for where it's used. Why not just fix the issue of "there is an argument that is not being forwarded correctly", without arguing about what the "correct" way to have attention masking is?
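As a self-contained illustration of the two mask flavours being discussed (a boolean mask, where False positions behave like -inf logits, and a float mask, which torch adds to the attention logits before the softmax, so adding log(s) scales a key's unnormalized attention weight by s), consider the following sketch; the shapes and values are arbitrary:

import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 4, 8)  # (batch, heads, queries, head_dim)
k = torch.randn(1, 1, 6, 8)  # 6 key positions
v = torch.randn(1, 1, 6, 8)

# Bool mask: block the last two key positions entirely (True = attend).
bool_mask = torch.ones(1, 1, 4, 6, dtype=torch.bool)
bool_mask[..., 4:] = False
out_blocked = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)

# Float mask: down-weight (rather than block) the last two key positions by 0.05.
scale = 0.05
float_mask = torch.zeros(1, 1, 4, 6)
float_mask[..., 4:] = torch.log(torch.tensor(scale))
out_biased = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

# The same effect shown directly on one row of logits: adding log(scale) multiplies
# the unnormalized softmax weight of those positions by `scale` (the final
# probabilities also shift through the renormalization constant).
logits = torch.randn(6)
plain = torch.softmax(logits, dim=-1)
biased = torch.softmax(logits + float_mask[0, 0, 0], dim=-1)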
If you wanted to just fix it, you could copy-paste it, but the discussion here is about what's best for the Diffusers project and all of its users, not just a select few. It would also be nice to see some performance numbers for the use of sparse attention masks, which are notoriously slow.
@bonlime would you be able to share your code? It looks super cool; we are working on supporting a prompt for Flux Redux here too: #10056. I think it is very easy to support passing a custom attention mask down, so we can support that first if there is a meaningful use case; cc @hlky here, let me know what you think too. As for a default implementation, happy to explore and support that too, but we need to see tests/experiments, so it's a bigger decision to make. I think it's being worked on in #10044.
Yes, I think passing a custom attention mask through the pipeline is the easiest thing to support initially; sounds like that would work for both @christopher5106 and @bonlime.
@yiyixuxu here is the code:

# assuming we already got reference_embeds somehow
prompt_embeds = torch.cat([prompt_embeds, *reference_embeds], dim=1)
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=torch.bfloat16)

# hardcoded for now since it doesn't really change; may need adjustment if you
# want to support multiple images as reference
cond_size = 729
full_attention_size = max_sequence_length + latents.size(1) + cond_size

# float mask: zeros everywhere, log(reference_scale) on the reference-token columns,
# so those keys are down-weighted rather than masked out entirely
attention_mask = torch.zeros(
    (full_attention_size, full_attention_size), device=latents.device, dtype=latents.dtype
)
reference_scale: float = 0.05  # example
bias = torch.log(
    torch.tensor(reference_scale, dtype=latents.dtype, device=latents.device).clamp(min=1e-5, max=1)
)
attention_mask[:, max_sequence_length : max_sequence_length + cond_size] = bias

# you need a patched attn_processor to use this argument
joint_attention_kwargs = dict(attention_mask=attention_mask)
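For completeness, a sketch of where this joint_attention_kwargs would be passed. It assumes a pipeline and attention processor patched to forward attention_mask (not guaranteed by the current diffusers release), and pooled_prompt_embeds is assumed to have been computed alongside prompt_embeds:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt_embeds=prompt_embeds,                # text + reference embeds concatenated above
    pooled_prompt_embeds=pooled_prompt_embeds,  # assumed available from the text encoders
    joint_attention_kwargs={"attention_mask": attention_mask},
).images[0]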
@yiyixuxu we have a fully functional attention-mask mechanism implemented here, which people have used to de-distill models; it was ported into kohya's trainer scripts and has been adapted for inference elsewhere. Thanks to @AmericanPresidentJimmyCarter for writing this code. It has been tested, experimented with, and used for months now, but we haven't upstreamed it because of the issues mentioned in the discussion from last month.
Ensuring that your implementation is backwards compatible with the one for LibreFLUX/SimpleTuner would be good. There have been a number of finetunes since it was released. |
Oh thanks! @rootonchair is adding the implementation you're referring to in #10044, no?
@bonlime
Both
@yiyixuxu correct, one line change to make this work |
Describe the bug
Is it possible to get back the attention_mask argument in the Flux attention processor (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L1910) in order to tweak things a bit? Otherwise the attention_mask argument is unused. Thanks a lot.
Reproduction
pip install diffusers
Logs
No response
System Info
Ubuntu
Who can help?
@yiyixuxu @sayakpaul @DN6 @asomoza