Support for cross-attention bias / mask #2634

Merged
merged 25 commits on May 22, 2023

Changes from all commits (25 commits)
57bdc4f
Cross-attention masks
Birch-san Mar 29, 2023
bd763a4
encoder_attention_mask for SimpleCrossAttnDownBlock2D, SimpleCrossAtt…
Birch-san May 10, 2023
d2f99b9
encoder_attention_mask for UNetMidBlock2DSimpleCrossAttn
Birch-san May 10, 2023
9ad3ed9
support attention_mask, encoder_attention_mask in KCrossAttnDownBlock…
Birch-san May 10, 2023
fecc595
fix mistake made during merge conflict resolution
Birch-san May 11, 2023
370daf5
regenerate versatile_diffusion
Birch-san May 11, 2023
b60b7ed
pass time embedding into checkpointed attention invocation
Birch-san May 11, 2023
faef7ac
always assume encoder_attention_mask is a mask (i.e. not a bias).
Birch-san May 11, 2023
725c27a
style, fix-copies
Birch-san May 11, 2023
e0437ae
add tests for cross-attention masks
Birch-san May 11, 2023
1a68b65
add test for padding of attention mask
Birch-san May 12, 2023
ea2948e
explain mask's query_tokens dim. fix explanation about broadcasting o…
Birch-san May 12, 2023
786249e
support both masks and biases in Transformer2DModel#forward. document…
Birch-san May 12, 2023
23440ef
fix-copies
Birch-san May 12, 2023
9155215
delete attention_mask docs on the basis I never tested self-attention…
Birch-san May 12, 2023
479574e
review feedback: the standard Unet blocks shouldn't pass temb to attn…
Birch-san May 18, 2023
7e9ba8f
remove encoder_attention_mask param from SimpleCrossAttn{Up,Down}Bloc…
Birch-san May 18, 2023
ba3da64
put attention mask padding back to how it was (since the SD use-case …
Birch-san May 18, 2023
0c020cf
fix-copies
Birch-san May 18, 2023
3ed583f
style
Birch-san May 18, 2023
c4931a0
fix-copies
Birch-san May 18, 2023
0d78e8b
put encoder_attention_mask param back into Simple block forward inter…
Birch-san May 19, 2023
937ab70
restore passing of emb to KAttentionBlock#forward, on the basis that …
Birch-san May 20, 2023
83762b2
make simple unet2d blocks use encoder_attention_mask, but only when a…
Birch-san May 20, 2023
8c323c5
fix copies
Birch-san May 20, 2023
18 changes: 8 additions & 10 deletions src/diffusers/models/attention.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
@@ -120,13 +120,13 @@ def __init__(

def forward(
self,
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
timestep=None,
cross_attention_kwargs=None,
class_labels=None,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
):
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
@@ -155,8 +155,6 @@ def forward(
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
# TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
# prepare attention mask here

attn_output = self.attn2(
norm_hidden_states,
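The TODO removed above is resolved by moving mask preparation upstream: by the time an `encoder_attention_mask` reaches `BasicTransformerBlock`, it is already an additive bias, prepared in `Transformer2DModel#forward` (or earlier, in `UNet2DConditionModel#forward`). Below is a minimal sketch of calling the block directly with such a bias; the constructor hyperparameters and shapes are illustrative assumptions, not values taken from this PR.

```python
import torch

from diffusers.models.attention import BasicTransformerBlock

# Illustrative sizes (assumptions, not values from the PR).
batch, query_tokens, key_tokens = 2, 64, 77
dim, heads, cross_dim = 320, 8, 768

# Constructor arguments are assumed from the diffusers API of the time.
block = BasicTransformerBlock(
    dim=dim,
    num_attention_heads=heads,
    attention_head_dim=dim // heads,
    cross_attention_dim=cross_dim,
)

hidden_states = torch.randn(batch, query_tokens, dim)
encoder_hidden_states = torch.randn(batch, key_tokens, cross_dim)

# By this point the caller has already turned a boolean keep/discard mask into an
# additive bias (0 = keep, -10000 = discard) with a singleton query_tokens dim.
keep = torch.ones(batch, key_tokens, dtype=torch.bool)
keep[:, 70:] = False  # pretend the last prompt tokens are padding
encoder_attention_mask = (1 - keep.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)  # [batch, 1, key_tokens]

out = block(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
)
print(out.shape)  # torch.Size([2, 64, 320])
```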
33 changes: 29 additions & 4 deletions src/diffusers/models/attention_processor.py
@@ -373,14 +373,24 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None,
if attention_mask is None:
return attention_mask

if attention_mask.shape[-1] != target_length:
current_length: int = attention_mask.shape[-1]
if current_length > target_length:
# we *could* trim the mask with:
# attention_mask = attention_mask[:,:target_length]
# but this is weird enough that it's more likely to be a mistake than a shortcut
raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
elif current_length < target_length:
if attention_mask.device.type == "mps":
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
# Instead, we can manually construct the padding tensor.
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat([attention_mask, padding], dim=2)
else:
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
# we want to instead pad by (0, remaining_length), where remaining_length is:
# remaining_length: int = target_length - current_length
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

if out_dim == 3:
@@ -813,7 +823,13 @@ class XFormersAttnProcessor:
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
):
residual = hidden_states

input_ndim = hidden_states.ndim
@@ -822,11 +838,20 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
batch_size, key_tokens, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
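To make the shape bookkeeping in `XFormersAttnProcessor.__call__` concrete, here is a standalone sketch of why the singleton `query_tokens` dimension is expanded before the bias is applied. Sizes are assumptions, and the attention is computed as a plain additive-bias softmax rather than an actual xformers call.

```python
import torch

# Illustrative sizes only, not code from the PR.
batch, heads, query_tokens, key_tokens = 2, 8, 64, 77

# After Attention#prepare_attention_mask, the cross-attention mask is an additive
# bias of shape [batch * heads, 1, key_tokens] (0.0 = keep, -10000.0 = discard),
# with the head dimension folded into the batch dimension.
bias = torch.zeros(batch * heads, 1, key_tokens)
bias[:, :, 70:] = -10000.0  # e.g. discard prompt padding tokens past position 70

# xformers' memory_efficient_attention does not broadcast the singleton
# query_tokens dimension, so the processor expands it explicitly.
bias = bias.expand(-1, query_tokens, -1)  # [batch * heads, query_tokens, key_tokens]

# The bias is added onto the raw attention scores before softmax; the discarded
# key tokens end up with (almost) zero attention weight.
scores = torch.randn(batch * heads, query_tokens, key_tokens)
probs = (scores + bias).softmax(dim=-1)
print(probs[0, 0, 70:].sum())  # ~0
```

Note that the padding branch in `prepare_attention_mask` deliberately keeps the original `F.pad(attention_mask, (0, target_length))` behaviour: per the TODO and commit ba3da64, restoring it preserved the existing Stable Diffusion use-case, and the stricter `remaining_length` padding (with its test) is left for a follow-up.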
2 changes: 1 addition & 1 deletion src/diffusers/models/embeddings.py
@@ -352,7 +352,7 @@ def token_drop(self, labels, force_drop_ids=None):
labels = torch.where(drop_ids, self.num_classes, labels)
return labels

def forward(self, labels, force_drop_ids=None):
def forward(self, labels: torch.LongTensor, force_drop_ids=None):
Contributor
hey @Birch-san, PR's great :) Just in the future, it'd be super helpful not to add type annotations when the argument isn't otherwise part of the PR; no issue with adding type annotations you feel are important in a separate PR. It just helps us quickly look at a diff of the function definition and see which arguments were added or removed. As is, we have to look at each argument manually to tell whether it's just an added type annotation or an actual change to the argument. No biggie :)

Contributor Author
will try to be more restrained in future

use_dropout = self.dropout_prob > 0
if (self.training and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
47 changes: 40 additions & 7 deletions src/diffusers/models/transformer_2d.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
@@ -213,11 +213,13 @@ def __init__(

def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
class_labels=None,
cross_attention_kwargs=None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
@@ -228,11 +230,17 @@ def forward(
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
timestep ( `torch.LongTensor`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
conditioning.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask, applied to encoder_hidden_states. Two formats supported:
Mask `(batch, sequence_length)` True = keep, False = discard. Bias `(batch, 1, sequence_length)` 0
= keep, -10000 = discard.
If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

@@ -241,6 +249,29 @@
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
Contributor
Great!

# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)

# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

# 1. Input
if self.is_input_continuous:
batch, _, height, width = hidden_states.shape
@@ -264,7 +295,9 @@
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
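To illustrate the two `encoder_attention_mask` formats documented in the `Transformer2DModel#forward` docstring above, the following standalone sketch (assumed shapes) shows that a 2-D keep/discard mask and an equivalent 3-D bias end up as the same additive bias.

```python
import torch

# Assumed shapes for illustration.
batch, key_tokens = 2, 77
dtype = torch.float32

# Format 1: a 2-D keep/discard mask, True/1 = keep, False/0 = discard.
mask = torch.ones(batch, key_tokens, dtype=torch.bool)
mask[:, 70:] = False  # pretend the last prompt tokens are padding

# Transformer2DModel#forward converts any ndim == 2 mask into an additive bias,
# exactly as in the diff, and adds a singleton query_tokens dimension so it can
# broadcast over the attention scores.
bias_from_mask = ((1 - mask.to(dtype)) * -10000.0).unsqueeze(1)  # (batch, 1, key_tokens)

# Format 2: the caller can instead pass an equivalent 3-D bias directly;
# because its ndim is 3, forward leaves it untouched.
bias_direct = torch.zeros(batch, 1, key_tokens, dtype=dtype)
bias_direct[:, :, 70:] = -10000.0

assert torch.allclose(bias_from_mask, bias_direct)
```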