diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 0b313b83d360..a7a9a472d9e9 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, Optional import torch import torch.nn.functional as F @@ -120,13 +120,13 @@ def __init__( def forward( self, - hidden_states, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - timestep=None, - cross_attention_kwargs=None, - class_labels=None, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, ): # Notice that normalization is always applied before the real computation in the following blocks. # 1. Self-Attention @@ -155,8 +155,6 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly - # prepare attention mask here attn_output = self.attn2( norm_hidden_states, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a489814c4787..8bb3a6f2ad8f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -373,7 +373,13 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, if attention_mask is None: return attention_mask - if attention_mask.shape[-1] != target_length: + current_length: int = attention_mask.shape[-1] + if current_length > target_length: + # we *could* trim the mask with: + # attention_mask = attention_mask[:,:target_length] + # but this is weird enough that it's more likely to be a mistake than a shortcut + raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).") + elif current_length < target_length: if attention_mask.device.type == "mps": # HACK: MPS: Does not support padding by greater than dimension of input tensor. # Instead, we can manually construct the padding tensor. 
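The prepare_attention_mask change above replaces the old `shape[-1] != target_length` test with two explicit cases: a mask longer than the key sequence now raises instead of being silently trimmed, and a shorter mask is padded out with "keep" values. A minimal standalone sketch of that intended behaviour (the helper name is illustrative; note it pads by the remaining length, which is what the TODO in the next hunk says the cross-attention path ultimately wants, whereas the shipped code still pads by target_length):

import torch
import torch.nn.functional as F

def pad_attention_bias(attention_mask: torch.Tensor, target_length: int) -> torch.Tensor:
    # attention_mask: additive bias whose last dim indexes key tokens, e.g. [batch, 1, key_tokens]
    current_length = attention_mask.shape[-1]
    if current_length > target_length:
        # an over-long mask is more likely a caller bug than a deliberate shortcut, so fail loudly
        raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
    if current_length < target_length:
        # pad the key dimension on the right with 0.0 ("keep") so the extra key tokens stay attended to
        remaining_length = target_length - current_length
        attention_mask = F.pad(attention_mask, (0, remaining_length), value=0.0)
    return attention_mask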
@@ -381,6 +387,10 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat([attention_mask, padding], dim=2) else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) if out_dim == 3: @@ -813,7 +823,13 @@ class XFormersAttnProcessor: def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ): residual = hidden_states input_ndim = hidden_states.ndim @@ -822,11 +838,20 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - batch_size, sequence_length, _ = ( + batch_size, key_tokens, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index fa88bce305e6..fb803039b268 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -352,7 +352,7 @@ def token_drop(self, labels, force_drop_ids=None): labels = torch.where(drop_ids, self.num_classes, labels) return labels - def forward(self, labels, force_drop_ids=None): + def forward(self, labels: torch.LongTensor, force_drop_ids=None): use_dropout = self.dropout_prob > 0 if (self.training and use_dropout) or (force_drop_ids is not None): labels = self.token_drop(labels, force_drop_ids) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index fde1014bd2e7..ec4cb371845f 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
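The XFormersAttnProcessor hunk above prepares the mask against the key length and then expands its singleton query dimension, because xformers adds the bias to attention scores of shape [batch * heads, query_tokens, key_tokens] without broadcasting that dimension for us. A small illustration with made-up sizes (the numbers are only for the example; expand returns a view, so no extra memory is allocated):

import torch

batch_heads, key_tokens, query_tokens = 16, 77, 4096  # illustrative sizes only

# what attn.prepare_attention_mask hands back: one bias row per batch * heads element
attention_mask = torch.zeros(batch_heads, 1, key_tokens)

# xformers expects the bias to already match the score shape, so broadcast the
# singleton query_tokens dimension explicitly
attention_mask = attention_mask.expand(-1, query_tokens, -1)
assert attention_mask.shape == (batch_heads, query_tokens, key_tokens)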
from dataclasses import dataclass -from typing import Optional +from typing import Any, Dict, Optional import torch import torch.nn.functional as F @@ -213,11 +213,13 @@ def __init__( def forward( self, - hidden_states, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ): """ @@ -228,11 +230,17 @@ def forward( encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. - timestep ( `torch.long`, *optional*): + timestep ( `torch.LongTensor`, *optional*): Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels conditioning. + encoder_attention_mask ( `torch.Tensor`, *optional* ). + Cross-attention mask, applied to encoder_hidden_states. Two formats supported: + Mask `(batch, sequence_length)` True = keep, False = discard. Bias `(batch, 1, sequence_length)` 0 + = keep, -10000 = discard. + If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. @@ -241,6 +249,29 @@ def forward( [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # 1. 
Input if self.is_input_continuous: batch, _, height, width = hidden_states.shape @@ -264,7 +295,9 @@ def forward( for block in self.transformer_blocks: hidden_states = block( hidden_states, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 75d9eb3e03df..6f8e3d0f5500 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, Optional, Tuple import numpy as np import torch @@ -558,14 +558,22 @@ def __init__( self.resnets = nn.ModuleList(resnets) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] hidden_states = resnet(hidden_states, temb) @@ -659,16 +667,34 @@ def __init__( self.resnets = nn.ModuleList(resnets) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. 
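The attention_mask / encoder_attention_mask selection above is repeated in every cross-attention block that forwards a mask straight into `Attention`. Restated as a hypothetical helper (not part of the diff), the rule is: an explicitly supplied attention_mask wins, to stay compatible with UnCLIP, and otherwise the cross-attention mask is used only when encoder_hidden_states is present:

from typing import Optional

import torch

def select_attention_mask(
    attention_mask: Optional[torch.Tensor],
    encoder_attention_mask: Optional[torch.Tensor],
    encoder_hidden_states: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    # an explicitly supplied attention_mask wins: UnCLIP passes its cross-attn mask this way
    if attention_mask is not None:
        return attention_mask
    # otherwise use the cross-attn mask only if we are actually cross-attending
    return None if encoder_hidden_states is None else encoder_attention_mask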
+ # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, + attention_mask=mask, **cross_attention_kwargs, ) @@ -850,9 +876,14 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - # TODO(Patrick, William) - attention mask is not used output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -867,33 +898,32 @@ def custom_forward(*inputs): return custom_forward - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - use_reentrant=False, - )[0] - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] @@ -1501,11 +1531,28 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. 
+ mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: @@ -1523,6 +1570,7 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + mask, cross_attention_kwargs, )[0] else: @@ -1531,7 +1579,7 @@ def custom_forward(*inputs): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, + attention_mask=mask, **cross_attention_kwargs, ) @@ -1690,7 +1738,13 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () @@ -1706,29 +1760,23 @@ def custom_forward(*inputs): return custom_forward - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - ) + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + cross_attention_kwargs, + encoder_attention_mask, + **ckpt_kwargs, + ) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( @@ -1737,6 +1785,7 @@ def custom_forward(*inputs): emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) if self.downsamplers is None: @@ -1916,15 +1965,15 @@ def __init__( def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = 
None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - # TODO(Patrick, William) - attention mask is not used for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1942,33 +1991,32 @@ def custom_forward(*inputs): return custom_forward - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - use_reentrant=False, - )[0] - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] @@ -2594,15 +2642,28 @@ def __init__( def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, - attention_mask=None, - cross_attention_kwargs=None, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. 
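Throughout these blocks the duplicated `is_torch_version` branches around torch.utils.checkpoint collapse into a single call by building the version-dependent kwargs up front; and because checkpoint() only forwards positional arguments, the refactored calls pass explicit None placeholders (timestep, class_labels) to keep the wrapped attention block's argument order intact. A rough sketch of the pattern with a hypothetical wrapper (the diff itself does this inline using diffusers' is_torch_version helper):

from typing import Any, Dict

import torch
from packaging import version
from torch.utils.checkpoint import checkpoint

def checkpointed(module: torch.nn.Module, *args) -> torch.Tensor:
    # build the version-dependent kwargs once instead of duplicating the whole checkpoint call
    supports_non_reentrant = version.parse(torch.__version__) >= version.parse("1.11.0")
    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if supports_non_reentrant else {}
    # checkpoint() only forwards positional arguments, so callers must pass every argument
    # up to the ones they care about, inserting None for the ones they skip
    return checkpoint(module, *args, **ckpt_kwargs)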
+ # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + for resnet, attn in zip(self.resnets, self.attentions): # resnet # pop res hidden states @@ -2626,6 +2687,7 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + mask, cross_attention_kwargs, )[0] else: @@ -2634,7 +2696,7 @@ def custom_forward(*inputs): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, + attention_mask=mask, **cross_attention_kwargs, ) @@ -2811,13 +2873,14 @@ def __init__( def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): res_hidden_states_tuple = res_hidden_states_tuple[-1] if res_hidden_states_tuple is not None: @@ -2835,29 +2898,23 @@ def custom_forward(*inputs): return custom_forward - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - use_reentrant=False, - )[0] - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - )[0] + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + cross_attention_kwargs, + encoder_attention_mask, + **ckpt_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( @@ -2866,6 +2923,7 @@ def custom_forward(*inputs): emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) if self.upsamplers is not None: @@ -2944,11 +3002,14 @@ def _to_4d(self, hidden_states, height, weight): def forward( self, - hidden_states, - encoder_hidden_states=None, - emb=None, - attention_mask=None, - cross_attention_kwargs=None, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + # TODO: mark emb as non-optional (self.norm2 requires it). + # requires assessing impact of change to positional param interface. 
+ emb: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} @@ -2962,6 +3023,7 @@ def forward( attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=None, + attention_mask=attention_mask, **cross_attention_kwargs, ) attn_output = self._to_4d(attn_output, height, weight) @@ -2976,6 +3038,7 @@ def forward( attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, **cross_attention_kwargs, ) attn_output = self._to_4d(attn_output, height, weight) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 2a4c9fd72c1b..76a40ffa1ec5 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -618,6 +618,7 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -625,6 +626,10 @@ def forward( sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + encoder_attention_mask (`torch.Tensor`): + (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = + discard. Mask will be converted into a bias, which adds large negative values to attention scores + corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): @@ -651,11 +656,27 @@ def forward( logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True - # prepare attention_mask + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # 0. 
center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -727,6 +748,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -752,6 +774,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) if mid_block_additional_residual is not None: @@ -778,6 +801,7 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 7aaa0e49e1da..29cde43337d2 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -721,6 +721,7 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -728,6 +729,10 @@ def forward( sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + encoder_attention_mask (`torch.Tensor`): + (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = + discard. Mask will be converted into a bias, which adds large negative values to attention scores + corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): @@ -754,11 +759,27 @@ def forward( logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True - # prepare attention_mask + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # 0. 
center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -830,6 +851,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -855,6 +877,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) if mid_block_additional_residual is not None: @@ -881,6 +904,7 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( @@ -1188,9 +1212,14 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - # TODO(Patrick, William) - attention mask is not used output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -1205,33 +1234,32 @@ def custom_forward(*inputs): return custom_forward - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - use_reentrant=False, - )[0] - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] @@ -1414,15 +1442,15 @@ def __init__( def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = 
None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - # TODO(Patrick, William) - attention mask is not used for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1440,33 +1468,32 @@ def custom_forward(*inputs): return custom_forward - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - use_reentrant=False, - )[0] - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] @@ -1564,14 +1591,22 @@ def __init__( self.resnets = nn.ModuleList(resnets) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] hidden_states = resnet(hidden_states, temb) @@ -1666,16 +1701,34 @@ def __init__( self.resnets = nn.ModuleList(resnets) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. 
+ mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, + attention_mask=mask, **cross_attention_kwargs, ) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index d3ca5ea3048e..90e60ae5d16d 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -20,6 +20,7 @@ import torch from parameterized import parameterized +from pytest import mark from diffusers import UNet2DConditionModel from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, LoRAAttnProcessor @@ -417,6 +418,76 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma assert processor.is_run assert processor.number == 123 + @parameterized.expand( + [ + # fmt: off + [torch.bool], + [torch.long], + [torch.float], + # fmt: on + ] + ) + def test_model_xattn_mask(self, mask_dtype): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)}) + model.to(torch_device) + model.eval() + + cond = inputs_dict["encoder_hidden_states"] + with torch.no_grad(): + full_cond_out = model(**inputs_dict).sample + assert full_cond_out is not None + + keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype) + full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample + assert full_cond_keepallmask_out.allclose( + full_cond_out + ), "a 'keep all' mask should give the same result as no mask" + + trunc_cond = cond[:, :-1, :] + trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample + assert not trunc_cond_out.allclose( + full_cond_out + ), "discarding the last token from our cond should change the result" + + batch, tokens, _ = cond.shape + mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype) + masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample + assert masked_cond_out.allclose( + trunc_cond_out + ), "masking the last token from our cond should be equivalent to truncating that token out of the condition" + + # see diffusers.models.attention_processor::Attention#prepare_attention_mask + # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks. + # since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric. + # maybe it's fine that this only works for the unclip use-case. + @mark.skip( + reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length." 
+ ) + def test_model_xattn_padding(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)}) + model.to(torch_device) + model.eval() + + cond = inputs_dict["encoder_hidden_states"] + with torch.no_grad(): + full_cond_out = model(**inputs_dict).sample + assert full_cond_out is not None + + batch, tokens, _ = cond.shape + keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool) + keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample + assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result" + + trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool) + trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample + assert trunc_mask_out.allclose( + keeplast_out + ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask." + def test_lora_processors(self): # enable deterministic behavior for gradient checkpointing init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
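The new test_model_xattn_mask test above exercises the behaviour this PR enables end to end: a "keep all" mask reproduces the unmasked output, and masking the final token matches truncating that token from the condition. A sketch of the same usage outside the test harness, assuming a small randomly-initialised UNet (every config value below is illustrative, not taken from the PR); internally the boolean mask is converted to an additive 0 / -10000 bias with a singleton query dimension, as in the unet_2d_condition.py hunk above:

import torch
from diffusers import UNet2DConditionModel

# a tiny randomly-initialised UNet; the config values are illustrative only
unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)
unet.eval()

sample = torch.randn(2, 4, 32, 32)   # (batch, in_channels, height, width)
timestep = torch.tensor([10, 10])
cond = torch.randn(2, 77, 32)        # (batch, tokens, cross_attention_dim)

# boolean cross-attention mask: True = keep, False = discard; hiding the last token
# should match truncating that token from `cond`, which is what the test asserts
encoder_attention_mask = torch.ones(2, 77, dtype=torch.bool)
encoder_attention_mask[:, -1] = False

with torch.no_grad():
    out = unet(
        sample,
        timestep,
        encoder_hidden_states=cond,
        encoder_attention_mask=encoder_attention_mask,
    ).sample
print(out.shape)  # torch.Size([2, 4, 32, 32])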