
Commit 9aea4a2

Committed Mar 11, 2023

support cross-attention masking for SlicedAttnProcessor

1 parent: ad7484b

File tree: 1 file changed, +38 −16 lines

Diff for: src/diffusers/models/cross_attention.py (+38 −16)
@@ -622,31 +622,53 @@ class SlicedAttnProcessor:
     def __init__(self, slice_size):
         self.slice_size = slice_size
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
+    def __call__(
+        self,
+        attn: CrossAttention,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        encoder_attention_bias: Optional[FloatTensor] = None,
+    ):
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            if encoder_attention_bias is not None:
+                if attention_mask is not None:
+                    # it's not well-defined whether `attention_mask` should be passed to self-attention, cross-attention, neither* or both.
+                    # if two sources of bias (`attention_mask`, `encoder_attention_bias`) are provided: it's likely to be a mistake.
+                    raise ValueError(f"two attention biases have been supplied: `attention_mask` and `encoder_attention_bias`. expected a maximum of one source of bias.")
+                attention_mask = encoder_attention_bias
+                # make broadcastable over query tokens
+                # TODO: see if there's a satisfactory way to unify how the `attention_mask`/`encoder_attention_bias` code paths
+                # create this singleton dim. the way AttnProcessor2_0 does it could work.
+                # here I'm trying to avoid interfering with the original `attention_mask` code path,
+                # by limiting the unsqueeze() to just the `encoder_attention_bias` path, on the basis that
+                # `attention_mask` is already working without this change.
+                # maybe it's because UNet2DConditionModel#forward unsqueeze()s `attention_mask` earlier.
+                attention_mask = attention_mask.unsqueeze(-2)
+            if attn.cross_attention_norm:
+                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
 
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        batch_size, key_tokens, _ = encoder_hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
 
         query = attn.to_q(hidden_states)
-        dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)
 
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
-
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
-        batch_size_attention = query.shape[0]
+        batch_x_heads, query_tokens, _ = query.shape
+        inner_dim = attn.to_q.out_features
+        channels_per_head = inner_dim // attn.heads
         hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
+            (batch_x_heads, query_tokens, channels_per_head), device=query.device, dtype=query.dtype
         )
 
-        for i in range(hidden_states.shape[0] // self.slice_size):
+        for i in range(batch_x_heads // self.slice_size):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size
 
@@ -662,10 +684,10 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
+        linear_proj, dropout = attn.to_out
+
+        hidden_states = linear_proj(hidden_states)
+        hidden_states = dropout(hidden_states)
 
         return hidden_states
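A note on the new argument, not part of the commit itself: once `encoder_attention_bias` is copied into `attention_mask`, it goes through `prepare_attention_mask` and is applied to the attention scores the same way the existing additive mask is. Below is a minimal PyTorch sketch of building such a bias from a boolean keep-mask and of what the `unsqueeze(-2)` broadcast achieves; the shapes and the padding cut-off are illustrative assumptions.

import torch

# illustrative shapes only (not taken from the commit)
batch_size, query_tokens, key_tokens = 2, 4096, 77

# boolean keep-mask over the encoder's key tokens: True = attend, False = ignore
keep = torch.ones(batch_size, key_tokens, dtype=torch.bool)
keep[:, 70:] = False  # pretend the trailing text tokens are padding

# turn it into an additive bias: 0 where attention is allowed, a very large
# negative value where it is masked, so softmax drives those weights to ~0
encoder_attention_bias = torch.zeros(batch_size, key_tokens)
encoder_attention_bias.masked_fill_(~keep, torch.finfo(torch.float32).min)

# the processor's unsqueeze(-2) adds a singleton query-token dim so the bias
# broadcasts over every query position: (batch, key_tokens) -> (batch, 1, key_tokens)
bias = encoder_attention_bias.unsqueeze(-2)

# broadcasting against a block of attention scores (inside the processor the
# leading dim is batch * heads after prepare_attention_mask; plain batch here)
scores = torch.randn(batch_size, query_tokens, key_tokens)
probs = (scores + bias).softmax(dim=-1)
print(probs[:, :, 70:].max())  # ~0: masked key tokens receive no attention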
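And a hedged end-to-end sketch of exercising the patched processor on a checkout that contains this commit. The `CrossAttention` constructor arguments and `set_processor` come from the same cross_attention.py module rather than from this diff, and the module sizes and tensor shapes are arbitrary.

import torch
from diffusers.models.cross_attention import CrossAttention, SlicedAttnProcessor

# hypothetical module sizes (inner_dim = heads * dim_head = 320)
attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
attn.set_processor(SlicedAttnProcessor(slice_size=2))

hidden_states = torch.randn(2, 4096, 320)        # (batch, query_tokens, query_dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, key_tokens, cross_attention_dim)

# additive bias over key tokens: 0 = keep, large negative = mask out
encoder_attention_bias = torch.zeros(2, 77)
encoder_attention_bias[:, 70:] = torch.finfo(torch.float32).min

# CrossAttention.forward forwards extra keyword arguments to its processor,
# so the new argument can be supplied per call
out = attn(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_bias=encoder_attention_bias,
)
print(out.shape)  # torch.Size([2, 4096, 320])

The slice_size of 2 divides batch * heads (2 * 8 = 16) evenly, which the loop `for i in range(batch_x_heads // self.slice_size)` assumes.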