
Support for cross-attention bias / mask #2634

Merged: 25 commits, May 22, 2023

Conversation

@Birch-san (Contributor) commented Mar 10, 2023

Resolves #1890.

Why do we want masking?
In stable-diffusion, our prompt may be 10 tokens long, but CLIP segments are padded to 77 tokens long.
Attention is forced to attend over our useless PAD word embeddings.
This is, incidentally, a mistake that seeped into the core of stable-diffusion's training: uncond is supposed to be 2 tokens long, but SD was trained without a mask, accidentally attending to 75 extraneous PAD/EOS word embeddings.
This is also useful for CLIP stitching (splicing multiple CLIP segments together to attend to a longer context) -- you can avoid attending to the duplicate BOS embeddings that each CLIP segment contributes.
More generally: language transformers benefit from causal masks, to only attend to previous information in the sequence when predicting next tokens.
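
As an aside, here's a minimal sketch of what such a causal mask looks like (purely illustrative; causal masking is not what this PR adds):

import torch

seq_len = 5
# True = may attend: position i may only attend to positions 0..i
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()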

As per @patrickvonplaten's suggestion: I've implemented this in CrossAttnProcessor and friends.

Strictly speaking, the suggestion was to make a new CrossAttnProcessor subclass. But this capability shouldn't be relegated to some "alternative" processor; masking is part of the definition of scaled dot-product attention: softmax(QK^T / sqrt(d) + M) V, where M is an additive mask/bias.

There is pre-existing support for an attention_mask argument. It's underspecified though. Like, is it for self-attention? Cross-attention? Both?
I found that for the UNet blocks used by stable-diffusion: attention_mask doesn't actually get passed as far as CrossAttention.

On the basis that attention_mask is underspecified, I made (as per @bonlime's design) a new parameter, encoder_attention_bias, with a more limited responsibility: it's only for cross-attention masking. Defining this is essential for knowing what sequence length it's supposed to have.

I'm following the design precedent whereby "we don't know whether we're a self-attention or cross-attention layer, until we sniff at runtime whether encoder_hidden_states is None". Personally I think that convention is a bit weird, but whatever.
It could be nice to have better separation of this, e.g. separate SelfAttention vs CrossAttention classes, and separate self_attention_kwargs vs cross_attention_kwargs.

Rather than implementing a mask, I generalized the problem to implementing a bias: now we can boost q_proj @ k_proj dot products, rather than just masking them.
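
As a rough illustration (not the diffusers implementation), biased scaled dot-product attention in pure PyTorch looks something like this:

import torch

def attention_with_bias(q, k, v, bias=None):
    # q, k, v: [batch*heads, seq_len, head_dim]; bias broadcastable to [batch*heads, q_len, k_len]
    scale = q.shape[-1] ** -0.5
    scores = torch.bmm(q, k.transpose(-1, -2)) * scale
    if bias is not None:
        # a bias of 0 leaves a position untouched, a large negative value masks it,
        # and a positive value boosts it
        scores = scores + bias
    probs = scores.softmax(dim=-1)
    return torch.bmm(probs, v)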

I integrated against this new API in diffusers-play:
Birch-san/diffusers-play@c50b9bf

To construct a "bias" from a BoolTensor mask: I provide a mask_to_bias helper.

If your CLIP conditions are:
[2, 77, 768]

Your mask: BoolTensor will be:
[2, 77]

Your bias: FloatTensor will likewise be:
[2, 77]
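
A minimal sketch of the kind of conversion such a helper performs (hypothetical; the actual mask_to_bias implementation may differ):

import torch

def mask_to_bias_sketch(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # mask: BoolTensor of shape [batch, key_tokens]; True = attend, False = ignore.
    # returns an additive bias: 0.0 where attending, -10000.0 (effectively -inf) where ignored
    return (~mask).to(dtype) * -10000.0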

I've implemented it for all the CrossAttnProcessors I know how to test:

  • CrossAttnProcessor
  • AttnProcessor2_0
  • SlicedAttnProcessor
  • XFormersCrossAttnProcessor

XFormers was fun because most of its algorithms don't support a bias tensor, or require an A100, or require token length to be a multiple of 8.

If you want to use XFormers, you can pad your mask and CLIP embedding to a multiple of 8 like so:

from torch.nn.functional import pad

# mask is a BoolTensor:
# [2, 77]
# embed is a FloatTensor:
# [2, 77, 768]

mask_length = mask.shape[-1]
# number of extra tokens needed to reach a multiple of 8 (0 if already aligned)
extra_tokens_needed = -mask_length % 8
# pad mask with False out to a multiple of 8 tokens
mask = pad(mask, (0, extra_tokens_needed))
# replicate-pad embedding to a multiple of 8 tokens (the mask will hide the extra tokens)
embed = pad(embed, (0, 0, 0, extra_tokens_needed), 'replicate')

Eventually I intend to contribute my pure-PyTorch memory-efficient processor, and I hope to add masking support to it at that time.

====

Examples:

I've taken a 74-token prompt and masked it such that only two tokens survive (the BOS and EOS embeddings). Here's with and without the mask:

Flandre's identity is lost; only the high-level semantics pooled into the EOS embedding survive.

Here I unmask just BOS, EOS, and Flandre's name:

Flandre's hair colour and clothing are corrected, but the number of subjects is wrong.

Here I unmask BOS, EOS, Flandre's name, the artist, and the number of subjects:

The number of subjects is now correct.

@HuggingFaceDocBuilderDev commented Mar 10, 2023

The documentation is not available anymore as the PR was closed or merged.

@bonlime (Contributor) commented Mar 10, 2023

Great to see this PR! In my experience, masking is crucial for short prompts to have good quality. It also helps with removing many deformed limbs.

Examples (1. without masking, 2. with masking): [image attachments omitted]

@damian0815 (Contributor) commented Mar 10, 2023

Do you have an example of just using a stock StableDiffusionPipeline? The following is failing with a tensor size mismatch error:

prompt = "a cat playing with a ball in the forest"
seed = 123
mask = torch.tensor([[True] + [True]*9 + [True] + [False]*66]) # mask out everything after <eos>
bias = mask_to_bias(mask, dtype=torch.float32)
print(bias.shape)

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
fixed_seed_generator = Generator(device="cpu").manual_seed(seed)
image = pipeline(prompt=prompt, 
                 cross_attention_kwargs={'encoder_attention_bias': bias},
                 num_inference_steps=7, 
                 generator=fixed_seed_generator).images[0]

The output:

torch.Size([1, 77])
...
~/2.current/stablediffusion/diffusers/src/diffusers/models/cross_attention.py in __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, encoder_attention_bias)
    341         value = attn.head_to_batch_dim(value)
    342 
--> 343         attention_probs = attn.get_attention_scores(query, key, attention_mask)
    344         hidden_states = torch.bmm(attention_probs, value)
    345         hidden_states = attn.batch_to_head_dim(hidden_states)

~/2.current/stablediffusion/diffusers/src/diffusers/models/cross_attention.py in get_attention_scores(self, query, key, attention_mask)
    251             key.transpose(-1, -2),
    252             beta=beta,
--> 253             alpha=self.scale,
    254         )
    255 

RuntimeError: The expanded size of the tensor (16) must match the existing size (8) at non-singleton dimension 0.  Target sizes: [16, 4096, 77].  Tensor sizes: [8, 1, 77]

@Birch-san (Contributor, Author)

@damian0815 Your code looks good already; the only thing you're missing is a mask for uncond (because you have CFG enabled).

I expect it'll work if you do this (I'm assuming uncond is the first condition in the batch):

mask = torch.tensor([
  [True]*77, # stable-diffusion was trained with an unmasked uncond, so sadly it learned to rely on those PAD embeds
  [True] + [True]*9 + [True] + [False]*66, # mask out everything after <eos>
])
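
With that two-row mask, the earlier snippet becomes (combining the two messages above; same hypothetical encoder_attention_bias usage):

bias = mask_to_bias(mask, dtype=torch.float32)  # now shape [2, 77]: row 0 = uncond, row 1 = cond
image = pipeline(prompt=prompt,
                 cross_attention_kwargs={'encoder_attention_bias': bias},
                 num_inference_steps=7,
                 generator=fixed_seed_generator).images[0]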

@damian0815 (Contributor) commented Mar 10, 2023

That did it, thanks.

Wow. This is significant.

@@ -272,15 +278,22 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
if attention_mask is None:
return attention_mask

if attention_mask.shape[-1] != target_length:
current_length: int = attention_mask.shape[-1]
Very nice!
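
For context, a simplified sketch of what this prepare_attention_mask padding might look like (hypothetical; the real diffusers helper differs in detail):

import torch.nn.functional as F

def prepare_attention_mask_sketch(attention_mask, target_length, heads):
    if attention_mask is None:
        return attention_mask
    current_length = attention_mask.shape[-1]
    if current_length != target_length:
        # pad the key dimension out to target_length; padded positions get bias 0.0 (i.e. still attended)
        attention_mask = F.pad(attention_mask, (0, target_length - current_length), value=0.0)
    # repeat per attention head so the mask lines up with [batch * heads, query_len, key_len] scores
    return attention_mask.repeat_interleave(heads, dim=0)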

@patrickvonplaten (Contributor)

Hey @Birch-san,

Thanks a lot for the PR! I think there is a slight misunderstanding 😅 and I think my answer here: #1890 (comment) was not good.

Note, I'm using attention_mask and encoder_attention_bias interchangeably in the following.

I fully agree with what you said. SAI probably just didn't use the attention_mask because they forgot it, and also because it doesn't really make a big difference. Therefore, we also don't use it in diffusers: for stable diffusion we disable the attention_mask by default. However, we should definitely allow users to make use of the attention_mask. This is already more or less possible. The only restriction is this here:

# TODO(Patrick, William) - attention mask is not used

Here we just don't pass the attention_mask at the moment because SD doesn't do it, but we should pass it.
=> Could you try to make changes so that we pass the attention_mask, and then we can see if we need to adapt it? Contrary to what I said before, we should probably not adapt it inside the attention processors but before (e.g. here), so that we don't have to change all the attention processors.

It would be amazing if we could adapt this PR a bit to:

  • Prepare the attention mask before passing it to the attention processors; Try to just have one attention_mask arg instead of having to add a new encoder_attention_bias arg
  • Not create a one-liner function to create the attention -10000 vector (this hurts readability too much IMO)

@patrickvonplaten (Contributor)

Please let me know if this doesn't make sense - super nice PR and investigation, would be super cool to get this in 🚀

@Birch-san (Contributor, Author)

I did originally have a stab at doing it the way you proposed, but the hard part is that if you only have one attention_mask param, there's no way to know whether the mask is intended for self-attention or for cross-attention.

@patrickvonplaten (Contributor) commented Mar 13, 2023

That makes a lot of sense! I see exactly what you mean. Having looked a bit more into it, I think we've accidentally introduced a bug. I fixed it here: #2656. I think it should be much clearer now to add encoder attention mask functionality, no?

@Birch-san force-pushed the cross_attn_mask_3 branch 2 times, most recently from 012dbce to 41316d7 (March 29, 2023, 21:38)
@Birch-san (Contributor, Author) commented Mar 29, 2023

@patrickvonplaten I've reimplemented cross-attention bias in the way you've indicated. This time, minimal changes were needed to the AttnProcessors.

@Birch-san force-pushed the cross_attn_mask_3 branch 2 times, most recently from aa7cd2f to 4232ad0 (March 29, 2023, 21:48)
@damian0815 (Contributor) commented Mar 30, 2023

Other than consistency with the existing kwarg attention_mask, is there a reason the kwarg encoder_attention_mask (which is actually interpreted as a bias if it's a FloatTensor) shouldn't be called encoder_attention_bias?

I can see that in the lower levels a BoolTensor gets converted to a FloatTensor, but this won't help users understand what's going wrong if they pass a tensor like [1, 1, 1, 1, 0, 0, 0, 0, 0...] to the encoder_attention_mask kwarg.

@Birch-san (Contributor, Author)

@damian0815 In the previous review iteration, the parameter was indeed named as a bias, and rather than converting a BoolTensor to a bias automatically, I left that choice to the user by exposing a mask_to_bias function.

I'd happily go back to that model, but @patrickvonplaten asked me to change it. So now I'm trying to guess what will get this through review.

@damian0815 (Contributor)

all right then haha

Birch-san added 19 commits May 20, 2023 12:22
@Birch-san force-pushed the cross_attn_mask_3 branch from 8f12786 to 8c323c5 (May 20, 2023, 11:22)
):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

if attention_mask is None:
Works for me! @williamberman @sayakpaul one final check here?

@patrickvonplaten (Contributor)

Thanks a mille for re-iterating here with us. We'll merge this PR and closely monitor the slow tests to be sure nothing broke.

@patrickvonplaten merged commit 64bf5d3 into huggingface:main on May 22, 2023
dg845 pushed a commit to dg845/diffusers that referenced this pull request May 23, 2023
* Cross-attention masks

prefer qualified symbol, fix accidental Optional

prefer qualified symbol in AttentionProcessor

prefer qualified symbol in embeddings.py

qualified symbol in transformed_2d

qualify FloatTensor in unet_2d_blocks

move new transformer_2d params attention_mask, encoder_attention_mask to the end of the section which is assumed (e.g. by functions such as checkpoint()) to have a stable positional param interface. regard return_dict as a special-case which is assumed to be injected separately from positional params (e.g. by create_custom_forward()).

move new encoder_attention_mask param to end of CrossAttn block interfaces and Unet2DCondition interface, to maintain positional param interface.

regenerate modeling_text_unet.py

remove unused import

unet_2d_condition encoder_attention_mask docs

Co-authored-by: Pedro Cuenca <[email protected]>

versatile_diffusion/modeling_text_unet.py encoder_attention_mask docs

Co-authored-by: Pedro Cuenca <[email protected]>

transformer_2d encoder_attention_mask docs

Co-authored-by: Pedro Cuenca <[email protected]>

unet_2d_blocks.py: add parameter name comments

Co-authored-by: Pedro Cuenca <[email protected]>

revert description. bool-to-bias treatment happens in unet_2d_condition only.

comment parameter names

fix copies, style

* encoder_attention_mask for SimpleCrossAttnDownBlock2D, SimpleCrossAttnUpBlock2D

* encoder_attention_mask for UNetMidBlock2DSimpleCrossAttn

* support attention_mask, encoder_attention_mask in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D, KAttentionBlock. fix binding of attention_mask, cross_attention_kwargs params in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D checkpoint invocations.

* fix mistake made during merge conflict resolution

* regenerate versatile_diffusion

* pass time embedding into checkpointed attention invocation

* always assume encoder_attention_mask is a mask (i.e. not a bias).

* style, fix-copies

* add tests for cross-attention masks

* add test for padding of attention mask

* explain mask's query_tokens dim. fix explanation about broadcasting over channels; we actually broadcast over query tokens

* support both masks and biases in Transformer2DModel#forward. document behaviour

* fix-copies

* delete attention_mask docs on the basis I never tested self-attention masking myself. not comfortable explaining it, since I don't actually understand how a self-attn mask can work in its current form: the key length will be different in every ResBlock (we don't downsample the mask when we downsample the image).

* review feedback: the standard Unet blocks shouldn't pass temb to attn (only to resnet). remove from KCrossAttnDownBlock2D,KCrossAttnUpBlock2D#forward.

* remove encoder_attention_mask param from SimpleCrossAttn{Up,Down}Block2D,UNetMidBlock2DSimpleCrossAttn, and mask-choice in those blocks' #forward, on the basis that they only do one type of attention, so the consumer can pass whichever type of attention_mask is appropriate.

* put attention mask padding back to how it was (since the SD use-case it enabled wasn't important, and it breaks the original unclip use-case). disable the test which was added.

* fix-copies

* style

* fix-copies

* put encoder_attention_mask param back into Simple block forward interfaces, to ensure consistency of forward interface.

* restore passing of emb to KAttentionBlock#forward, on the basis that removal caused test failures. restore also the passing of emb to checkpointed calls to KAttentionBlock#forward.

* make simple unet2d blocks use encoder_attention_mask, but only when attention_mask is None. this should fix UnCLIP compatibility.

* fix copies
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
Successfully merging this pull request may close the issue: Does attention masking actually work?

8 participants