Commit ad7484b

support cross-attention masking for XFormersCrossAttnProcessor

1 parent c864345 · commit ad7484b
File tree: 1 file changed

src/diffusers/models/cross_attention.py (+48 -17 lines changed)

Diff for: src/diffusers/models/cross_attention.py (+48 -17)
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, Dict, Any
 
 import torch
 import torch.nn.functional as F
@@ -198,7 +198,13 @@ def set_processor(self, processor: "AttnProcessor"):
 
         self.processor = processor
 
-    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
+    def forward(
+        self,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        **cross_attention_kwargs: Dict[str, Any]
+    ):
         # The `CrossAttention` class can call different attention processors / attention functions
         # here we simply pass along all tensors to the selected processor class
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty
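The widened `forward` signature is what lets a new per-call argument such as `encoder_attention_bias` reach the processor: whatever the caller puts in `**cross_attention_kwargs` is handed to the selected processor untouched. A minimal, framework-free sketch of that plumbing (the toy class and toy processor below are illustrative, not part of the library):

class ToyAttention:
    def __init__(self, processor):
        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        # same pattern as the diff above: pass everything along to the selected processor
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

def toy_processor(attn, hidden_states, encoder_hidden_states=None, attention_mask=None, encoder_attention_bias=None):
    # the extra kwarg arrives here without ToyAttention ever naming it
    return hidden_states if encoder_attention_bias is None else hidden_states + encoder_attention_bias

print(ToyAttention(toy_processor).forward(1.0, encoder_attention_bias=0.5))  # prints 1.5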
@@ -313,8 +319,12 @@ def __call__(
                     raise ValueError(f"two attention biases have been supplied: `attention_mask` and `encoder_attention_bias`. expected a maximum of one source of bias.")
                 attention_mask = encoder_attention_bias
                 # make broadcastable over query tokens
-                # TODO: consider aligning implementations such that AttnProcessor2_0 and CrossAttnProcessor do unsqueeze
-                # in the same way/circumstances -- AttnProcessor2_0 does it for `attention_mask` **and** for `encoder_attention_bias`.
+                # TODO: see if there's a satisfactory way to unify how the `attention_mask`/`encoder_attention_bias` code paths
+                # create this singleton dim. the way AttnProcessor2_0 does it could work.
+                # here I'm trying to avoid interfering with the original `attention_mask` code path,
+                # by limiting the unsqueeze() to just the `encoder_attention_bias` path, on the basis that
+                # `attention_mask` is already working without this change.
+                # maybe it's because UNet2DConditionModel#forward unsqueeze()s `attention_mask` earlier.
                 attention_mask = attention_mask.unsqueeze(-2)
             if attn.cross_attention_norm:
                 encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
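For context on what the `unsqueeze(-2)` above buys: the bias is additive and defined per key token, so giving it a singleton query dim lets it broadcast against attention scores of shape (batch, query_tokens, key_tokens), and a -inf entry silences that key for every query after softmax. A toy PyTorch illustration with made-up shapes (not code from this commit):

import torch

batch, query_tokens, key_tokens = 1, 3, 5
bias = torch.zeros(batch, key_tokens)
bias[:, 3:] = float("-inf")                            # e.g. mask out two padding tokens

scores = torch.randn(batch, query_tokens, key_tokens)
probs = (scores + bias.unsqueeze(-2)).softmax(dim=-1)  # (batch, 1, key_tokens) broadcasts over queries

print(probs[0, :, 3:])    # all zeros: masked keys get no attention from any query
print(probs.sum(dim=-1))  # each row still sums to 1 over the unmasked keys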
@@ -453,18 +463,39 @@ class XFormersCrossAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
-
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        query = attn.to_q(hidden_states)
-
+    def __call__(
+        self,
+        attn: CrossAttention,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        encoder_attention_bias: Optional[FloatTensor] = None,
+    ):
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        else:
+            if encoder_attention_bias is not None:
+                if attention_mask is not None:
+                    # it's not well-defined whether `attention_mask` should be passed to self-attention, cross-attention, neither* or both.
+                    # if two sources of bias (`attention_mask`, `encoder_attention_bias`) are provided: it's likely to be a mistake.
+                    raise ValueError(f"two attention biases have been supplied: `attention_mask` and `encoder_attention_bias`. expected a maximum of one source of bias.")
+                attention_mask = encoder_attention_bias
+
+                # TODO: figure out why the original `attention_mask` code path didn't attempt broadcasting over query tokens.
+                # it feels like this logic would be needed in that code path too.
 
+                # make broadcastable over query tokens
+                attention_mask = attention_mask.unsqueeze(-2)
+                _, query_tokens, _ = hidden_states.shape
+                # xformers doesn't broadcast for us, so we expand our singleton dimension manually
+                attention_mask = attention_mask.expand(-1, query_tokens, -1)
+            if attn.cross_attention_norm:
+                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
+        batch_size, key_tokens, _ = encoder_hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
+
+        query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
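To isolate the shape mechanics of the new masking path: the per-key bias gains a singleton query dim, and because (per the comment in the hunk above) xformers will not broadcast that dim, it is materialised with expand(), which produces a zero-stride view rather than a copy. A toy illustration with made-up shapes (77 text tokens, 4096 image tokens), not code from this commit:

import torch

batch, query_tokens, key_tokens = 2, 4096, 77
bias = torch.zeros(batch, key_tokens)        # one bias value per key (text) token

bias = bias.unsqueeze(-2)                    # (2, 1, 77): broadcastable over query tokens
bias = bias.expand(-1, query_tokens, -1)     # (2, 4096, 77): explicit shape, still no extra memory

print(bias.shape)     # torch.Size([2, 4096, 77])
print(bias.stride())  # (77, 0, 1): the query dim is a zero-stride view, not a copy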
@@ -478,10 +509,10 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         hidden_states = hidden_states.to(query.dtype)
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
+        linear_proj, dropout = attn.to_out
+
+        hidden_states = linear_proj(hidden_states)
+        hidden_states = dropout(hidden_states)
         return hidden_states
 
 
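Taken together, the commit lets a caller hand a per-token bias for the text sequence straight to the attention module. A rough smoke-test sketch of that usage, assuming this branch's code plus xformers and a CUDA device; the CrossAttention constructor arguments and the shapes are illustrative guesses, not values fixed by the commit:

import torch
from diffusers.models.cross_attention import CrossAttention, XFormersCrossAttnProcessor

attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40).cuda()
attn.set_processor(XFormersCrossAttnProcessor())

hidden_states = torch.randn(2, 4096, 320, device="cuda")        # image tokens
encoder_hidden_states = torch.randn(2, 80, 768, device="cuda")  # text tokens
bias = torch.zeros(2, 80, device="cuda")
bias[:, 40:] = float("-inf")                                    # e.g. ignore trailing padding tokens

out = attn(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_bias=bias,  # reaches the processor via **cross_attention_kwargs
)
print(out.shape)                  # torch.Size([2, 4096, 320])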