diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 6e1dc1037c20..9d5f4a75ee1c 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -731,6 +731,43 @@ def prepare_attention_mask(
 
         return attention_mask
 
+    def prepare_joint_attention_mask(
+        self, attention_mask: torch.Tensor, target_length: int, dtype: torch.dtype
+    ) -> torch.Tensor:
+        if attention_mask is None:
+            return attention_mask
+
+        current_length: int = attention_mask.shape[-1]
+        remaining_length: int = target_length - current_length
+        if current_length != target_length:
+            if attention_mask.device.type == "mps":
+                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+                # Instead, we can manually construct the padding tensor.
+                padding_shape = (attention_mask.shape[0], remaining_length)
+                padding = torch.ones(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+                attention_mask = torch.cat([padding, attention_mask], dim=-1)
+            else:
+                attention_mask = F.pad(attention_mask, (remaining_length, 0), value=1.0)
+
+        if attention_mask.dim() == 3:
+            # If the provided attention mask has shape [batch_size, target_seq_length, src_seq_length],
+            # we only need to broadcast it to all the heads
+            attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # If the provided attention mask has shape [batch_size, seq_length],
+            # we broadcast it over both the heads and the query positions;
+            # there is no need to mask the query rows for padding tokens, since padding queries do not affect the non-padding tokens
+            attention_mask = attention_mask[:, None, None, :]
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and the dtype's smallest value for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min
+
+        return attention_mask
+
     def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
         r"""
         Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
@@ -1458,7 +1495,14 @@ def __call__(
             key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
             value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        if attention_mask is not None:
+            attention_mask = attn.prepare_joint_attention_mask(attention_mask, key.shape[2], key.dtype)
+        else:
+            attention_mask = None
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -2312,6 +2356,11 @@ def __call__(
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
             value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
 
+        if attention_mask is not None:
+            attention_mask = attn.prepare_joint_attention_mask(attention_mask, key.shape[2], key.dtype)
+        else:
+            attention_mask = None
+
         if image_rotary_emb is not None:
             from .embeddings import apply_rotary_emb
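For reference, here is a minimal self-contained sketch of the masking scheme the new helper implements (the name `build_joint_attention_bias` and the shapes below are illustrative, not part of the diff): a binary keep-mask over the text tokens is left-padded with ones so it spans the full joint (image + text) sequence, broadcast over heads and query positions, and converted into an additive bias that `F.scaled_dot_product_attention` adds to the attention scores.

```python
import torch
import torch.nn.functional as F


def build_joint_attention_bias(text_mask: torch.Tensor, target_length: int, dtype: torch.dtype) -> torch.Tensor:
    # text_mask: [batch, text_len] with 1.0 = attend, 0.0 = masked.
    pad = target_length - text_mask.shape[-1]
    mask = F.pad(text_mask, (pad, 0), value=1.0)  # left-pad with ones: the non-text positions are never masked
    mask = mask[:, None, None, :]                 # broadcast over heads and query positions
    return (1.0 - mask) * torch.finfo(dtype).min  # 0.0 where attended, large negative where masked


batch, heads, image_len, text_len, head_dim = 2, 4, 16, 8, 32
joint_len = image_len + text_len
query = torch.randn(batch, heads, joint_len, head_dim)
key = torch.randn_like(query)
value = torch.randn_like(query)

text_mask = torch.ones(batch, text_len)
text_mask[:, -3:] = 0.0  # pretend the last 3 text tokens are padding

bias = build_joint_attention_bias(text_mask, key.shape[2], key.dtype)
out = F.scaled_dot_product_attention(query, key, value, attn_mask=bias, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 24, 32])
```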