Support for cross-attention bias / mask #2634
Conversation
Do you have an example of just using a stock…
The output:
@damian0815 your code looks good already; the only thing you're missing is a mask for uncond (because you have CFG enabled). I expect it'll work if you do this (I'm assuming uncond is the first condition in the batch):

```python
mask = torch.tensor([
    [True]*77,  # stable-diffusion was trained with an unmasked uncond, so sadly it learned to rely on those PAD embeds
    [True] + [True]*9 + [True] + [False]*66,  # mask out everything after <eos>
])
```
that did it, thanks. wow. this is significant.
(force-pushed from 4fb67fb to 9aea4a2)
```diff
@@ -272,15 +278,22 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
     if attention_mask is None:
         return attention_mask

     if attention_mask.shape[-1] != target_length:
         current_length: int = attention_mask.shape[-1]
```
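For context, a minimal sketch of what this padding branch amounts to when the supplied mask is shorter than the key length (illustrative only, not the library's exact code; the function name is hypothetical, and the correct fill value depends on whether the mask is boolean or already an additive bias):

```python
from typing import Optional

import torch
import torch.nn.functional as F


def pad_mask_to_key_length(attention_mask: Optional[torch.Tensor], target_length: int) -> Optional[torch.Tensor]:
    # nothing to do if no mask was supplied
    if attention_mask is None:
        return attention_mask
    current_length: int = attention_mask.shape[-1]
    if current_length != target_length:
        # extend the trailing (key/token) dim out to target_length;
        # 0.0 is used here purely as a placeholder fill value
        attention_mask = F.pad(attention_mask, (0, target_length - current_length), value=0.0)
    return attention_mask
```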
Very nice!
Hey @Birch-san, thanks a lot for the PR! I think there is a slight misunderstanding 😅 and I think my answer here (#1890 (comment)) was not good. I fully agree with what you said. SAI probably just didn't use the attention_mask because they forgot it, and also because it doesn't really make a big difference. Therefore, we also don't use it in `diffusers/src/diffusers/models/unet_2d_blocks.py` (line 1786 in 6766a81).

=> Could you try to make changes there, so that we pass the attention_mask, and then we can see if we need to adapt it? Contrary to what I said before, we should probably not adapt it inside the attention processors but before (e.g. here: `diffusers/src/diffusers/models/attention.py`, line 288 in 6766a81). It would be amazing if we could adapt this PR a bit to do that.

Please let me know if this doesn't make sense - super nice PR and investigation, would be super cool to get this in 🚀
I did originally have a stab at doing it the way you proposed, but the hard part is that if you only have one…
That makes a lot of sense! I see exactly what you mean. Having looked a bit more into it, I think we've accidentally introduced a bug. I fixed it here: #2656 - I think it should be much clearer now to add encoder attention mask functionality, no?
(force-pushed from 012dbce to 41316d7)
@patrickvonplaten I've reimplemented cross-attention bias in the way you've indicated. This time, minimal changes were needed to AttnProcessors.
(force-pushed from aa7cd2f to 4232ad0)
Other than consistency with the existing kwarg, I can see that in the lower levels a…
@damian0815 in the previous review iteration, the property was indeed named a bias, and rather than converting BoolTensor to a bias automatically, I left that choice to the user by exposing a `mask_to_bias` helper. I'd happily go back to that model, but @patrickvonplaten asked me to change, so now I'm trying to guess what'll get this through review.
all right then haha |
(force-pushed from 8f12786 to 8c323c5)
```python
    ):
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        if attention_mask is None:
```
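For context, this PR's commit messages describe the mask-selection logic for the "simple" UNet blocks as: use `encoder_attention_mask`, but only when `attention_mask` is `None`. A sketch of that logic (illustrative, not the exact diff):

```python
from typing import Optional

import torch


def choose_mask(
    attention_mask: Optional[torch.Tensor],
    encoder_attention_mask: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    # prefer an explicitly supplied attention_mask; otherwise fall back to the
    # encoder (cross-attention) mask
    return attention_mask if attention_mask is not None else encoder_attention_mask
```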
works for me! @williamberman @sayakpaul one final check here?
Thanks a mille for re-iterating here with us. We'll merge this PR and closely monitor the slow tests to be sure nothing broke
* Cross-attention masks
  * prefer qualified symbol, fix accidental Optional
  * prefer qualified symbol in AttentionProcessor
  * prefer qualified symbol in embeddings.py
  * qualified symbol in transformed_2d
  * qualify FloatTensor in unet_2d_blocks
  * move new transformer_2d params attention_mask, encoder_attention_mask to the end of the section which is assumed (e.g. by functions such as checkpoint()) to have a stable positional param interface. regard return_dict as a special-case which is assumed to be injected separately from positional params (e.g. by create_custom_forward()).
  * move new encoder_attention_mask param to end of CrossAttn block interfaces and Unet2DCondition interface, to maintain positional param interface.
  * regenerate modeling_text_unet.py
  * remove unused import
  * unet_2d_condition encoder_attention_mask docs (Co-authored-by: Pedro Cuenca <[email protected]>)
  * versatile_diffusion/modeling_text_unet.py encoder_attention_mask docs (Co-authored-by: Pedro Cuenca <[email protected]>)
  * transformer_2d encoder_attention_mask docs (Co-authored-by: Pedro Cuenca <[email protected]>)
  * unet_2d_blocks.py: add parameter name comments (Co-authored-by: Pedro Cuenca <[email protected]>)
  * revert description. bool-to-bias treatment happens in unet_2d_condition only.
  * comment parameter names
  * fix copies, style
* encoder_attention_mask for SimpleCrossAttnDownBlock2D, SimpleCrossAttnUpBlock2D
* encoder_attention_mask for UNetMidBlock2DSimpleCrossAttn
* support attention_mask, encoder_attention_mask in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D, KAttentionBlock. fix binding of attention_mask, cross_attention_kwargs params in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D checkpoint invocations.
* fix mistake made during merge conflict resolution
* regenerate versatile_diffusion
* pass time embedding into checkpointed attention invocation
* always assume encoder_attention_mask is a mask (i.e. not a bias).
* style, fix-copies
* add tests for cross-attention masks
* add test for padding of attention mask
* explain mask's query_tokens dim. fix explanation about broadcasting over channels; we actually broadcast over query tokens
* support both masks and biases in Transformer2DModel#forward. document behaviour
* fix-copies
* delete attention_mask docs on the basis I never tested self-attention masking myself. not comfortable explaining it, since I don't actually understand how a self-attn mask can work in its current form: the key length will be different in every ResBlock (we don't downsample the mask when we downsample the image).
* review feedback: the standard Unet blocks shouldn't pass temb to attn (only to resnet). remove from KCrossAttnDownBlock2D,KCrossAttnUpBlock2D#forward.
* remove encoder_attention_mask param from SimpleCrossAttn{Up,Down}Block2D,UNetMidBlock2DSimpleCrossAttn, and mask-choice in those blocks' #forward, on the basis that they only do one type of attention, so the consumer can pass whichever type of attention_mask is appropriate.
* put attention mask padding back to how it was (since the SD use-case it enabled wasn't important, and it breaks the original unclip use-case). disable the test which was added.
* fix-copies
* style
* fix-copies
* put encoder_attention_mask param back into Simple block forward interfaces, to ensure consistency of forward interface.
* restore passing of emb to KAttentionBlock#forward, on the basis that removal caused test failures. restore also the passing of emb to checkpointed calls to KAttentionBlock#forward.
* make simple unet2d blocks use encoder_attention_mask, but only when attention_mask is None. this should fix UnCLIP compatibility.
* fix copies
Resolves #1890.
Why do we want masking?
In stable-diffusion, our prompt may be 10 tokens long, but CLIP segments are padded to 77 tokens long.
Attention is forced to attend over our useless PAD word embeddings.
This is incidentally a mistake that seeped into the core of stable-diffusion's training -- uncond is supposed to be 2 tokens long, but SD was trained without a mask, accidentally attending to 75 extraneous PAD/EOS word embeddings.
This is also useful for CLIP stitching (splicing multiple CLIP segments together to attend to a longer context) -- you can avoid attending to the duplicate BOS embeddings that each CLIP segment contributes.
More generally: language transformers benefit from causal masks, to only attend to previous information in the sequence when predicting next tokens.
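For concreteness, here is one way to obtain the kind of per-token mask being discussed, straight from the CLIP tokenizer's padding output (a sketch; the PR doesn't prescribe how you build the mask, and the model id is just an example):

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
out = tokenizer(
    "a photo of an astronaut riding a horse",
    padding="max_length",
    max_length=tokenizer.model_max_length,  # 77 for CLIP
    return_tensors="pt",
)
# attention_mask is [1, 77]: 1 for BOS + prompt tokens + EOS, 0 for the PAD positions
mask = out.attention_mask.bool()
```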
As per @patrickvonplaten's suggestion, I've implemented this in `CrossAttnProcessor` and friends. Strictly speaking, the suggestion was to make a new `CrossAttnProcessor` subclass. But this capability shouldn't be relegated to some "alternative" processor; masking is part of the definition of scaled dot product attention:
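For reference, the textbook operation with an additive bias term looks like this (a generic sketch, not this repo's implementation):

```python
import math

import torch


def sdp_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor:
    scores = q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1])
    if bias is not None:
        # masked positions carry a large negative bias, so softmax gives them ~0 weight
        scores = scores + bias
    return scores.softmax(dim=-1) @ v
```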
There is pre-existing support for an `attention_mask` argument. It's underspecified though. Like, is it for self-attention? Cross-attention? Both? I found that, for the UNet blocks used by stable-diffusion, `attention_mask` doesn't actually get passed as far as `CrossAttention`.
On the basis that `attention_mask` is underspecified, I made (as per @bonlime's design) a new parameter, `encoder_attention_bias`, with a more limited responsibility: it's only for cross-attention masking. Defining this is essential for knowing what sequence length it's supposed to have.

I'm following the design precedent whereby "we don't know whether we're a self-attention or cross-attention layer, until we sniff at runtime whether `encoder_hidden_states is None`". Personally I think that convention is a bit weird, but whatever. Could be nice to have better separation of this, e.g. separate `SelfAttention` vs `CrossAttention` classes, and separate `self_attention_kwargs` vs `cross_attention_kwargs`.
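For illustration, that convention amounts to something like this inside an attention processor (a sketch; the `encoder_hidden_states is None` check is the only point being made):

```python
from typing import Optional

import torch


def select_kv_source(
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # self-attention when encoder_hidden_states is None; cross-attention otherwise.
    # the k/v projections are then applied to whichever source is returned.
    return hidden_states if encoder_hidden_states is None else encoder_hidden_states
```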
Rather than implementing a mask, I generalized the problem to implementing bias. Now we can boost `q_proj @ k_proj` dot products, rather than just masking them.

I integrated against this new API in diffusers-play: Birch-san/diffusers-play@c50b9bf
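As merged, the UNet-level kwarg ended up being called `encoder_attention_mask` (see the commit list above). A sketch of a call against that interface, with the checkpoint id and all tensor values chosen purely for illustration:

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

latents = torch.randn(2, 4, 64, 64)
text_embeds = torch.randn(2, 77, 768)        # stand-in for CLIP hidden states
mask = torch.ones(2, 77, dtype=torch.bool)   # per-token mask over the CLIP sequence
mask[1, 11:] = False                         # e.g. mask PAD tokens of the second prompt

noise_pred = unet(
    latents,
    timestep=999,
    encoder_hidden_states=text_embeds,
    encoder_attention_mask=mask,
).sample                                     # [2, 4, 64, 64]
```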
To construct a "bias" from a `BoolTensor` mask, I provide a `mask_to_bias` helper. If your CLIP conditions are `[2, 77, 768]`: your `mask: BoolTensor` will be `[2, 77]`, and your `bias: FloatTensor` will likewise be `[2, 77]`.
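A sketch of what such a helper does (illustrative, not the PR's exact code): map `True` to `0.0` and `False` to a large negative number, so that softmax assigns the masked keys ~zero weight.

```python
import torch


def mask_to_bias(mask: torch.BoolTensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    bias = torch.zeros_like(mask, dtype=dtype)
    # wherever the mask is False, push the attention score towards -inf
    return bias.masked_fill(~mask, torch.finfo(dtype).min)


mask = torch.tensor([[True] * 77,
                     [True] * 11 + [False] * 66])  # [2, 77]
bias = mask_to_bias(mask)                          # [2, 77]: 0.0 where True, ~-inf where False
```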
I've implemented this for all the CrossAttnProcessors I know how to test:
XFormers was fun because most of its algorithms don't support a bias tensor, or require an A100, or require token length to be a multiple of 8.
If you want to use XFormers, you can pad your mask and CLIP embedding to a multiple of 8 like so:
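A minimal sketch of that padding, assuming a `[2, 77, 768]` CLIP embedding and a `[2, 77]` boolean mask (variable names are mine, not from the PR):

```python
import torch
import torch.nn.functional as F

embeds = torch.randn(2, 77, 768)                  # stand-in for the CLIP embedding
mask = torch.ones(2, 77, dtype=torch.bool)        # stand-in for the token mask

pad = (8 - embeds.shape[1] % 8) % 8               # 77 -> pad by 3, up to 80
embeds_padded = F.pad(embeds, (0, 0, 0, pad))     # [2, 80, 768]; new tokens are zero vectors
mask_padded = F.pad(mask, (0, pad), value=False)  # [2, 80]; padded tokens stay masked out
```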
Eventually I intend to contribute my pure-PyTorch memory-efficient processor, and I hope to add masking support to it at that time.
====
Examples:
I've taken a 74-token prompt, and masked it such that only two tokens survive (the BOS and EOS embeddings). Here's with and without the mask:

Flandre's identity is lost; only the high-level semantics pooled into the EOS embedding survive.

Here I unmask just BOS, EOS, and Flandre's name:
Flandre's hair colour and clothing are corrected, but the number of subjects is wrong.

Here I unmask BOS, EOS, Flandre's name, the artist, and the number of subjects:

The number of subjects is now correct.