From 57bdc4f175cfd712842fe0f0248d2dc80fbb6a75 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Wed, 29 Mar 2023 22:26:24 +0100 Subject: [PATCH 01/25] Cross-attention masks prefer qualified symbol, fix accidental Optional prefer qualified symbol in AttentionProcessor prefer qualified symbol in embeddings.py qualified symbol in transformed_2d qualify FloatTensor in unet_2d_blocks move new transformer_2d params attention_mask, encoder_attention_mask to the end of the section which is assumed (e.g. by functions such as checkpoint()) to have a stable positional param interface. regard return_dict as a special-case which is assumed to be injected separately from positional params (e.g. by create_custom_forward()). move new encoder_attention_mask param to end of CrossAttn block interfaces and Unet2DCondition interface, to maintain positional param interface. regenerate modeling_text_unet.py remove unused import unet_2d_condition encoder_attention_mask docs Co-authored-by: Pedro Cuenca versatile_diffusion/modeling_text_unet.py encoder_attention_mask docs Co-authored-by: Pedro Cuenca transformer_2d encoder_attention_mask docs Co-authored-by: Pedro Cuenca unet_2d_blocks.py: add parameter name comments Co-authored-by: Pedro Cuenca revert description. bool-to-bias treatment happens in unet_2d_condition only. comment parameter names fix copies, style --- src/diffusers/models/attention.py | 18 ++- src/diffusers/models/attention_processor.py | 27 ++++- src/diffusers/models/embeddings.py | 2 +- src/diffusers/models/transformer_2d.py | 22 ++-- src/diffusers/models/unet_2d_blocks.py | 123 +++++++++++--------- src/diffusers/models/unet_2d_condition.py | 15 +++ 6 files changed, 128 insertions(+), 79 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 0b313b83d360..a7a9a472d9e9 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, Optional import torch import torch.nn.functional as F @@ -120,13 +120,13 @@ def __init__( def forward( self, - hidden_states, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - timestep=None, - cross_attention_kwargs=None, - class_labels=None, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, ): # Notice that normalization is always applied before the real computation in the following blocks. # 1. 
Self-Attention @@ -155,8 +155,6 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly - # prepare attention mask here attn_output = self.attn2( norm_hidden_states, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a489814c4787..80911b6b251c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -373,7 +373,13 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, if attention_mask is None: return attention_mask - if attention_mask.shape[-1] != target_length: + current_length: int = attention_mask.shape[-1] + if current_length > target_length: + # we *could* trim the mask with: + # attention_mask = attention_mask[:,:target_length] + # but this is weird enough that it's more likely to be a mistake than a shortcut + raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).") + elif current_length < target_length: if attention_mask.device.type == "mps": # HACK: MPS: Does not support padding by greater than dimension of input tensor. # Instead, we can manually construct the padding tensor. @@ -381,7 +387,8 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat([attention_mask, padding], dim=2) else: - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + remaining_length: int = target_length - current_length + attention_mask = F.pad(attention_mask, (0, remaining_length), value=0.0) if out_dim == 3: if attention_mask.shape[0] < batch_size * head_size: @@ -813,7 +820,13 @@ class XFormersAttnProcessor: def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + ): residual = hidden_states input_ndim = hidden_states.ndim @@ -822,11 +835,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - batch_size, sequence_length, _ = ( + batch_size, key_tokens, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) + if attention_mask is not None: + # xformers doesn't broadcast for us, so we expand our singleton dimension manually + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index fa88bce305e6..fb803039b268 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -352,7 +352,7 @@ def token_drop(self, 
labels, force_drop_ids=None): labels = torch.where(drop_ids, self.num_classes, labels) return labels - def forward(self, labels, force_drop_ids=None): + def forward(self, labels: torch.LongTensor, force_drop_ids=None): use_dropout = self.dropout_prob > 0 if (self.training and use_dropout) or (force_drop_ids is not None): labels = self.token_drop(labels, force_drop_ids) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index fde1014bd2e7..033b094bd042 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Optional +from typing import Any, Dict, Optional import torch import torch.nn.functional as F @@ -213,11 +213,13 @@ def __init__( def forward( self, - hidden_states, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ): """ @@ -228,11 +230,15 @@ def forward( encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. - timestep ( `torch.long`, *optional*): + timestep ( `torch.LongTensor`, *optional*): Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels conditioning. + attention_mask ( `torch.Tensor` of shape (batch size, num latent pixels), *optional* ). + Bias to add to attention scores. + encoder_attention_mask ( `torch.Tensor` of shape (batch size, num encoder tokens), *optional* ). + Bias to add to cross-attention scores. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. @@ -264,7 +270,9 @@ def forward( for block in self.transformer_blocks: hidden_states = block( hidden_states, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 75d9eb3e03df..889eeb100718 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional +from typing import Any, Dict, Optional, Tuple import numpy as np import torch @@ -558,14 +558,22 @@ def __init__( self.resnets = nn.ModuleList(resnets) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn( + attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] hidden_states = resnet(hidden_states, temb) @@ -850,9 +858,14 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - # TODO(Patrick, William) - attention mask is not used output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -867,33 +880,32 @@ def custom_forward(*inputs): return custom_forward - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - use_reentrant=False, - )[0] - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] @@ -1916,15 +1928,15 @@ def __init__( def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: 
Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - # TODO(Patrick, William) - attention mask is not used for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1942,33 +1954,32 @@ def custom_forward(*inputs): return custom_forward - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - use_reentrant=False, - )[0] - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 2a4c9fd72c1b..0396ed022e80 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -618,6 +618,7 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -625,6 +626,10 @@ def forward( sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + encoder_attention_mask (`torch.Tensor`): + (batch, sequence_length) cross-attention mask (or bias), applied to encoder_hidden_states. If a + BoolTensor is provided, it will be turned into a bias, by adding a large negative value. False = hide + token. Other tensor types will be used as-is as bias values. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 
cross_attention_kwargs (`dict`, *optional*): @@ -656,6 +661,13 @@ def forward( attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) + # ensure encoder_attention_mask is a bias, and make it broadcastable over multi-head-attention channels + if encoder_attention_mask is not None: + # if it's a mask: turn it into a bias. otherwise: assume it's already a bias + if encoder_attention_mask.dtype is torch.bool: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -727,6 +739,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -752,6 +765,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) if mid_block_additional_residual is not None: @@ -778,6 +792,7 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( From bd763a44577e247c6a87874e4cd0afe3bf37d302 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Wed, 10 May 2023 23:33:07 +0100 Subject: [PATCH 02/25] encoder_attention_mask for SimpleCrossAttnDownBlock2D, SimpleCrossAttnUpBlock2D --- src/diffusers/models/unet_2d_blocks.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 889eeb100718..8df1e9386821 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1513,7 +1513,13 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} @@ -2605,13 +2611,14 @@ def __init__( def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, - attention_mask=None, - cross_attention_kwargs=None, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} for resnet, attn in zip(self.resnets, self.attentions): From d2f99b9de887caff6b6947691c303206e1118949 Mon Sep 17 00:00:00 2001 From: Alex Birch 
Date: Wed, 10 May 2023 23:39:50 +0100 Subject: [PATCH 03/25] encoder_attention_mask for UNetMidBlock2DSimpleCrossAttn --- src/diffusers/models/unet_2d_blocks.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 8df1e9386821..c2bbee4b57ab 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -667,7 +667,13 @@ def __init__( self.resnets = nn.ModuleList(resnets) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} hidden_states = self.resnets[0](hidden_states, temb) @@ -676,7 +682,7 @@ def forward( hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, + attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, **cross_attention_kwargs, ) From 9ad3ed98849ed89ee1c3542911336b51b60a19af Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Wed, 10 May 2023 23:56:10 +0100 Subject: [PATCH 04/25] support attention_mask, encoder_attention_mask in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D, KAttentionBlock. fix binding of attention_mask, cross_attention_kwargs params in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D checkpoint invocations. 
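Note: the binding fix matters because torch.utils.checkpoint.checkpoint forwards its extra arguments to the wrapped callable positionally, in declaration order; the explicit `None` placeholders in the checkpointed calls in the diff below keep later tensors bound to the right parameter names. A minimal sketch of the pitfall, assuming torch >= 1.11 for `use_reentrant` (`TinyBlock` is hypothetical, not diffusers code):

import torch
import torch.utils.checkpoint


class TinyBlock(torch.nn.Module):
    # hypothetical block; parameter order loosely mirrors the attention blocks' forward()
    def forward(self, hidden_states, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
        out = hidden_states
        if encoder_hidden_states is not None:
            out = out + encoder_hidden_states.mean()
        if encoder_attention_mask is not None:
            out = out + encoder_attention_mask.sum()
        return out


block = TinyBlock()
x = torch.randn(2, 4, 8, requires_grad=True)
cond = torch.randn(2, 3, 8)
cond_mask = torch.zeros(2, 1, 3)

# correct: the None placeholder (standing in for attention_mask) keeps cond bound to
# encoder_hidden_states and cond_mask bound to encoder_attention_mask
out = torch.utils.checkpoint.checkpoint(block, x, None, cond, cond_mask, use_reentrant=False)

# without the placeholder, cond would silently bind to attention_mask and cond_mask to
# encoder_hidden_states -- the class of bug this commit fixes in the checkpoint invocations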
--- src/diffusers/models/unet_2d_blocks.py | 118 ++++++++++++------------- 1 file changed, 59 insertions(+), 59 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index c2bbee4b57ab..c4c69390cfab 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1714,7 +1714,13 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () @@ -1730,29 +1736,23 @@ def custom_forward(*inputs): return custom_forward - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - use_reentrant=False, - ) - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - ) + ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # emb + attention_mask, + cross_attention_kwargs, + encoder_attention_mask, + **ckpt_kwargs, + ) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( @@ -1761,6 +1761,7 @@ def custom_forward(*inputs): emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) if self.downsamplers is None: @@ -2835,13 +2836,14 @@ def __init__( def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): res_hidden_states_tuple = res_hidden_states_tuple[-1] if res_hidden_states_tuple is not None: @@ -2859,29 +2861,23 @@ def custom_forward(*inputs): return custom_forward - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - 
encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - use_reentrant=False, - )[0] - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - )[0] + ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # emb + attention_mask, + cross_attention_kwargs, + encoder_attention_mask, + **ckpt_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( @@ -2890,6 +2886,7 @@ def custom_forward(*inputs): emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) if self.upsamplers is not None: @@ -2968,11 +2965,12 @@ def _to_4d(self, hidden_states, height, weight): def forward( self, - hidden_states, - encoder_hidden_states=None, - emb=None, - attention_mask=None, - cross_attention_kwargs=None, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + emb: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} @@ -2986,6 +2984,7 @@ def forward( attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=None, + attention_mask=attention_mask, **cross_attention_kwargs, ) attn_output = self._to_4d(attn_output, height, weight) @@ -3000,6 +2999,7 @@ def forward( attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, **cross_attention_kwargs, ) attn_output = self._to_4d(attn_output, height, weight) From fecc5953a85510adef47b183a1f66892ee317610 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Thu, 11 May 2023 22:20:14 +0100 Subject: [PATCH 05/25] fix mistake made during merge conflict resolution --- src/diffusers/models/unet_2d_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index c4c69390cfab..480b500ad255 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -568,7 +568,7 @@ def forward( ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - attn( + hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, From 370daf57918440ccb8b1e5cddda7620b8b9b936a Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Thu, 11 May 2023 22:20:36 +0100 Subject: [PATCH 06/25] regenerate versatile_diffusion --- .../versatile_diffusion/modeling_text_unet.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py 
b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 7aaa0e49e1da..acf4b242af9f 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1666,7 +1666,13 @@ def __init__( self.resnets = nn.ModuleList(resnets) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} hidden_states = self.resnets[0](hidden_states, temb) @@ -1675,7 +1681,7 @@ def forward( hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, + attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, **cross_attention_kwargs, ) From b60b7ed70dc7fe7cfdc86e6d95f8bec95ac893eb Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Thu, 11 May 2023 22:34:26 +0100 Subject: [PATCH 07/25] pass time embedding into checkpointed attention invocation --- src/diffusers/models/unet_2d_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 480b500ad255..91f757cc28ba 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1747,7 +1747,7 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, - None, # emb + temb, attention_mask, cross_attention_kwargs, encoder_attention_mask, @@ -2872,7 +2872,7 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, - None, # emb + temb, attention_mask, cross_attention_kwargs, encoder_attention_mask, From faef7acd44b0939e8f7d4552e50fb6079387f666 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Thu, 11 May 2023 22:59:51 +0100 Subject: [PATCH 08/25] always assume encoder_attention_mask is a mask (i.e. not a bias). --- src/diffusers/models/unet_2d_condition.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 0396ed022e80..e751a52c60a5 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -627,9 +627,9 @@ def forward( timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states encoder_attention_mask (`torch.Tensor`): - (batch, sequence_length) cross-attention mask (or bias), applied to encoder_hidden_states. If a - BoolTensor is provided, it will be turned into a bias, by adding a large negative value. False = hide - token. Other tensor types will be used as-is as bias values. + (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = discard. + Mask will be converted into a bias, which adds large negative values to attention scores corresponding to + "discard" tokens. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): @@ -656,16 +656,18 @@ def forward( logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True - # prepare attention_mask + # ensure attention_mask is a bias, and make it broadcastable over multi-head-attention channels if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) - # ensure encoder_attention_mask is a bias, and make it broadcastable over multi-head-attention channels + # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: - # if it's a mask: turn it into a bias. otherwise: assume it's already a bias - if encoder_attention_mask.dtype is torch.bool: - encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary From 725c27adf479e34253afd4fd4194e89317aa09c3 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Thu, 11 May 2023 23:26:40 +0100 Subject: [PATCH 09/25] style, fix-copies --- src/diffusers/models/unet_2d_condition.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index e751a52c60a5..19d6db2b1065 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -627,9 +627,9 @@ def forward( timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states encoder_attention_mask (`torch.Tensor`): - (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = discard. - Mask will be converted into a bias, which adds large negative values to attention scores corresponding to - "discard" tokens. + (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = + discard. Mask will be converted into a bias, which adds large negative values to attention scores + corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 
cross_attention_kwargs (`dict`, *optional*): From e0437ae0b56cd08eb2a315c9ba0a7ae91e9c8e69 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Fri, 12 May 2023 00:38:57 +0100 Subject: [PATCH 10/25] add tests for cross-attention masks --- tests/models/test_models_unet_2d_condition.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index d3ca5ea3048e..8a3c35345f7c 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -417,6 +417,46 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma assert processor.is_run assert processor.number == 123 + @parameterized.expand( + [ + # fmt: off + [torch.bool], + [torch.long], + [torch.float], + # fmt: on + ] + ) + def test_model_xattn_mask(self, mask_dtype): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)}) + model.to(torch_device) + model.eval() + + cond = inputs_dict["encoder_hidden_states"] + with torch.no_grad(): + full_cond_out = model(**inputs_dict).sample + assert full_cond_out is not None + + keepall_mask = torch.ones(*cond.shape[:-1]).to(cond.device, mask_dtype) + full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample + assert full_cond_keepallmask_out.allclose( + full_cond_out + ), "a 'keep all' mask should give the same result as no mask" + + trunc_cond = cond[:, :-1, :] + trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample + assert not trunc_cond_out.allclose( + full_cond_out + ), "discarding the last token from our cond should change the result" + + batch, tokens, _ = cond.shape + trunc_mask = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype) + masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample + assert masked_cond_out.allclose( + trunc_cond_out + ), "masking the last token from our cond should be equivalent to truncating that token out of the condition" + def test_lora_processors(self): # enable deterministic behavior for gradient checkpointing init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() From 1a68b6523fa5387a2fafd7ae0952650acf105bf4 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Fri, 12 May 2023 21:27:43 +0100 Subject: [PATCH 11/25] add test for padding of attention mask --- tests/models/test_models_unet_2d_condition.py | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 8a3c35345f7c..d5f54d872ee6 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -438,7 +438,7 @@ def test_model_xattn_mask(self, mask_dtype): full_cond_out = model(**inputs_dict).sample assert full_cond_out is not None - keepall_mask = torch.ones(*cond.shape[:-1]).to(cond.device, mask_dtype) + keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype) full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample assert full_cond_keepallmask_out.allclose( full_cond_out @@ -451,12 +451,35 @@ def test_model_xattn_mask(self, mask_dtype): ), "discarding the last token from our cond should change the result" batch, tokens, _ = cond.shape - trunc_mask = 
(torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype) - masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample + mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype) + masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample assert masked_cond_out.allclose( trunc_cond_out ), "masking the last token from our cond should be equivalent to truncating that token out of the condition" + def test_model_xattn_padding(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)}) + model.to(torch_device) + model.eval() + + cond = inputs_dict["encoder_hidden_states"] + with torch.no_grad(): + full_cond_out = model(**inputs_dict).sample + assert full_cond_out is not None + + batch, tokens, _ = cond.shape + keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool) + keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample + assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result" + + trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool) + trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample + assert trunc_mask_out.allclose( + keeplast_out + ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask." + def test_lora_processors(self): # enable deterministic behavior for gradient checkpointing init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() From ea2948e4021fcc741fd8a1f806c9bef16b62d436 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Fri, 12 May 2023 22:15:27 +0100 Subject: [PATCH 12/25] explain mask's query_tokens dim. fix explanation about broadcasting over channels; we actually broadcast over query tokens --- src/diffusers/models/attention_processor.py | 7 ++++++- src/diffusers/models/unet_2d_condition.py | 9 ++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 80911b6b251c..e02cebc3eb90 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -841,7 +841,12 @@ def __call__( attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) if attention_mask is not None: - # xformers doesn't broadcast for us, so we expand our singleton dimension manually + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. 
_, query_tokens, _ = hidden_states.shape attention_mask = attention_mask.expand(-1, query_tokens, -1) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 19d6db2b1065..76a40ffa1ec5 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -656,7 +656,14 @@ def forward( logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True - # ensure attention_mask is a bias, and make it broadcastable over multi-head-attention channels + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) From 786249eebc4e69f4bf2e11e2a903681b5ca10c3f Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Fri, 12 May 2023 22:43:33 +0100 Subject: [PATCH 13/25] support both masks and biases in Transformer2DModel#forward. document behaviour --- src/diffusers/models/unet_2d_condition.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 76a40ffa1ec5..06fc32ab37c2 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -626,6 +626,10 @@ def forward( sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + attention_mask (`torch.Tensor`): + (batch, sequence_length) self-attention mask, applied to sample. True = keep, False = discard. Mask + will be converted into a bias, which adds large negative values to attention scores corresponding to + "discard" tokens. encoder_attention_mask (`torch.Tensor`): (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = discard. Mask will be converted into a bias, which adds large negative values to attention scores From 23440ef8d14122a5e481ed4a5a70622e6fa4ce17 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Fri, 12 May 2023 22:43:48 +0100 Subject: [PATCH 14/25] fix-copies --- src/diffusers/models/transformer_2d.py | 33 ++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 033b094bd042..ec4cb371845f 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -235,10 +235,12 @@ def forward( class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels conditioning. - attention_mask ( `torch.Tensor` of shape (batch size, num latent pixels), *optional* ). - Bias to add to attention scores. - encoder_attention_mask ( `torch.Tensor` of shape (batch size, num encoder tokens), *optional* ). - Bias to add to cross-attention scores. + encoder_attention_mask ( `torch.Tensor`, *optional* ). 
+ Cross-attention mask, applied to encoder_hidden_states. Two formats supported: + Mask `(batch, sequence_length)` True = keep, False = discard. Bias `(batch, 1, sequence_length)` 0 + = keep, -10000 = discard. + If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. @@ -247,6 +249,29 @@ def forward( [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # 1. Input if self.is_input_continuous: batch, _, height, width = hidden_states.shape From 915521571e2c1869fb7c7d820fd164db7d385bfb Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Fri, 12 May 2023 22:50:36 +0100 Subject: [PATCH 15/25] delete attention_mask docs on the basis I never tested self-attention masking myself. not comfortable explaining it, since I don't actually understand how a self-attn mask can work in its current form: the key length will be different in every ResBlock (we don't downsample the mask when we downsample the image). --- src/diffusers/models/unet_2d_condition.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 06fc32ab37c2..76a40ffa1ec5 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -626,10 +626,6 @@ def forward( sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states - attention_mask (`torch.Tensor`): - (batch, sequence_length) self-attention mask, applied to sample. True = keep, False = discard. Mask - will be converted into a bias, which adds large negative values to attention scores corresponding to - "discard" tokens. 
encoder_attention_mask (`torch.Tensor`): (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = discard. Mask will be converted into a bias, which adds large negative values to attention scores From 479574e556fa56905267cf7d9bf421698020bbff Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Thu, 18 May 2023 23:56:29 +0100 Subject: [PATCH 16/25] review feedback: the standard Unet blocks shouldn't pass temb to attn (only to resnet). remove from KCrossAttnDownBlock2D,KCrossAttnUpBlock2D#forward. --- src/diffusers/models/unet_2d_blocks.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 91f757cc28ba..ae45f3a82c4e 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1747,7 +1747,7 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, - temb, + None, # emb attention_mask, cross_attention_kwargs, encoder_attention_mask, @@ -1758,7 +1758,6 @@ def custom_forward(*inputs): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, @@ -2872,7 +2871,7 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, - temb, + None, # temb attention_mask, cross_attention_kwargs, encoder_attention_mask, @@ -2883,7 +2882,6 @@ def custom_forward(*inputs): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, From 7e9ba8f3b147990a653afc1c5bde4429db2a1d39 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Fri, 19 May 2023 00:15:54 +0100 Subject: [PATCH 17/25] remove encoder_attention_mask param from SimpleCrossAttn{Up,Down}Block2D,UNetMidBlock2DSimpleCrossAttn, and mask-choice in those blocks' #forward, on the basis that they only do one type of attention, so the consumer can pass whichever type of attention_mask is appropriate. 
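To illustrate the consumer-side choice this leaves behind, a hedged sketch (the `call_simple_cross_attn_block` helper and its `block` argument are hypothetical, not diffusers APIs; only the single `attention_mask` parameter mirrors the blocks changed below):

from typing import Optional

import torch


def call_simple_cross_attn_block(
    block,  # hypothetical: any module with the single-attention interface described above
    hidden_states: torch.Tensor,
    temb: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,          # bias over latent tokens
    encoder_attention_mask: Optional[torch.Tensor] = None,  # bias over encoder tokens
):
    # the block attends over encoder_hidden_states when they are given, otherwise over
    # hidden_states itself, so the mask we forward must describe that same key sequence
    mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
    return block(
        hidden_states,
        temb=temb,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=mask,
    )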
--- src/diffusers/models/unet_2d_blocks.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index ae45f3a82c4e..392c3357bc63 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -673,7 +673,6 @@ def forward( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} hidden_states = self.resnets[0](hidden_states, temb) @@ -682,7 +681,7 @@ def forward( hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, + attention_mask=attention_mask, **cross_attention_kwargs, ) @@ -1525,7 +1524,6 @@ def forward( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} @@ -1547,6 +1545,7 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + attention_mask, cross_attention_kwargs, )[0] else: @@ -2624,7 +2623,6 @@ def forward( upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} for resnet, attn in zip(self.resnets, self.attentions): @@ -2650,6 +2648,7 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + attention_mask, cross_attention_kwargs, )[0] else: From ba3da64ab9f0be79bfd6fc8185b6e81357eb7b52 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Fri, 19 May 2023 00:29:31 +0100 Subject: [PATCH 18/25] put attention mask padding back to how it was (since the SD use-case it enabled wasn't important, and it breaks the original unclip use-case). disable the test which was added. 
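For reference, a small sketch of the two padding behaviours being swapped back here, with toy lengths (only `torch.nn.functional.pad` is real; `target_length`/`current_length` follow `prepare_attention_mask` in the diff below):

import torch
import torch.nn.functional as F

target_length = 77          # key length the attention expects
mask = torch.ones(1, 64)    # an under-length mask: current_length = 64
current_length = mask.shape[-1]

# restored behaviour (what the unclip use-case relies on): append target_length extra tokens
padded_by_target = F.pad(mask, (0, target_length), value=0.0)

# the reverted behaviour: pad only up to target_length
padded_by_remaining = F.pad(mask, (0, target_length - current_length), value=0.0)

print(padded_by_target.shape)     # torch.Size([1, 141])  i.e. 64 + 77
print(padded_by_remaining.shape)  # torch.Size([1, 77])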
--- src/diffusers/models/attention_processor.py | 7 +++++-- tests/models/test_models_unet_2d_condition.py | 8 ++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e02cebc3eb90..8bb3a6f2ad8f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -387,8 +387,11 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat([attention_mask, padding], dim=2) else: - remaining_length: int = target_length - current_length - attention_mask = F.pad(attention_mask, (0, remaining_length), value=0.0) + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) if out_dim == 3: if attention_mask.shape[0] < batch_size * head_size: diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index d5f54d872ee6..90e60ae5d16d 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -20,6 +20,7 @@ import torch from parameterized import parameterized +from pytest import mark from diffusers import UNet2DConditionModel from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, LoRAAttnProcessor @@ -457,6 +458,13 @@ def test_model_xattn_mask(self, mask_dtype): trunc_cond_out ), "masking the last token from our cond should be equivalent to truncating that token out of the condition" + # see diffusers.models.attention_processor::Attention#prepare_attention_mask + # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks. + # since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric. + # maybe it's fine that this only works for the unclip use-case. + @mark.skip( + reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length." 
+ ) def test_model_xattn_padding(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() From 0c020cf22f3cd777b2cb518b48e47ebe14d74cec Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Fri, 19 May 2023 00:29:49 +0100 Subject: [PATCH 19/25] fix-copies --- .../pipelines/versatile_diffusion/modeling_text_unet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index acf4b242af9f..ceee1973e5e7 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1672,7 +1672,6 @@ def forward( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} hidden_states = self.resnets[0](hidden_states, temb) @@ -1681,7 +1680,7 @@ def forward( hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, + attention_mask=attention_mask, **cross_attention_kwargs, ) From 3ed583f0ab3794cabfd3b635939e547ccb0486d5 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Fri, 19 May 2023 00:49:41 +0100 Subject: [PATCH 20/25] style --- src/diffusers/models/unet_2d_blocks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 392c3357bc63..cb158af3ac9c 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -885,7 +885,7 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version(">=", "1.11.0") else {} + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, @@ -1735,7 +1735,7 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version(">=", "1.11.0") else {} + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, @@ -1965,7 +1965,7 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version(">=", "1.11.0") else {} + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, @@ -2859,7 +2859,7 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version(">=", "1.11.0") else {} + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, From c4931a0468ff72f9a1b933b36fd5d5c797f3f4b8 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Fri, 19 May 2023 00:49:48 +0100 Subject: [PATCH 21/25] fix-copies --- 
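The regenerated code below copies the same mask-to-bias treatment into modeling_text_unet.py; as a standalone sketch of that arithmetic (toy shapes, plain tensors, not a diffusers API):

import torch

batch, query_tokens, key_tokens = 2, 4, 6
scores = torch.randn(batch, query_tokens, key_tokens)  # stand-in attention scores

# boolean keep-mask over encoder tokens: True = keep, False = discard
encoder_attention_mask = torch.tensor(
    [[1, 1, 1, 0, 0, 0],
     [1, 1, 1, 1, 1, 0]], dtype=torch.bool
)

# (1 = keep, 0 = discard) -> additive bias (keep = +0, discard = -10000.0),
# with a singleton query_tokens dimension so it broadcasts over the scores
bias = (1 - encoder_attention_mask.to(scores.dtype)) * -10000.0
bias = bias.unsqueeze(1)  # [batch, 1, key_tokens]

probs = (scores + bias).softmax(dim=-1)
assert probs[0, :, 3:].max() < 1e-4  # "discard" tokens get (numerically) zero attention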
.../versatile_diffusion/modeling_text_unet.py | 145 +++++++++++------- 1 file changed, 90 insertions(+), 55 deletions(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index ceee1973e5e7..dceaacc270d5 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -721,6 +721,7 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -728,6 +729,10 @@ def forward( sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + encoder_attention_mask (`torch.Tensor`): + (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = + discard. Mask will be converted into a bias, which adds large negative values to attention scores + corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): @@ -754,11 +759,27 @@ def forward( logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True - # prepare attention_mask + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # 0. 
center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -830,6 +851,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -855,6 +877,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) if mid_block_additional_residual is not None: @@ -881,6 +904,7 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( @@ -1188,9 +1212,14 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - # TODO(Patrick, William) - attention mask is not used output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -1205,33 +1234,32 @@ def custom_forward(*inputs): return custom_forward - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - use_reentrant=False, - )[0] - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] @@ -1414,15 +1442,15 @@ def __init__( def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = 
None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - # TODO(Patrick, William) - attention mask is not used for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1440,33 +1468,32 @@ def custom_forward(*inputs): return custom_forward - if is_torch_version(">=", "1.11.0"): - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, use_reentrant=False - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - use_reentrant=False, - )[0] - else: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] @@ -1564,14 +1591,22 @@ def __init__( self.resnets = nn.ModuleList(resnets) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] hidden_states = resnet(hidden_states, temb) From 0d78e8b9cbef25758027114f7266bb98cd71cf96 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Fri, 19 May 2023 01:03:31 +0100 Subject: [PATCH 22/25] put encoder_attention_mask param back into Simple block forward interfaces, to ensure consistency of forward interface. 
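For reference, the mask-to-bias conversion that the fix-copies above mirrors into modeling_text_unet.py reduces to a couple of tensor ops. A toy sketch with made-up values, assuming a (batch, key_tokens) mask of ones and zeros:

    import torch

    encoder_attention_mask = torch.tensor([[1, 1, 1, 0]])  # 1 = keep, 0 = discard
    sample_dtype = torch.float32

    # keep -> 0.0, discard -> -10000.0, so the mask can be added to attention scores as a bias
    bias = (1 - encoder_attention_mask.to(sample_dtype)) * -10000.0
    # singleton query_tokens dim lets it broadcast over
    # [batch, heads, query_tokens, key_tokens] or [batch * heads, query_tokens, key_tokens]
    bias = bias.unsqueeze(1)
    assert bias.shape == (1, 1, 4)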
--- src/diffusers/models/unet_2d_blocks.py | 6 ++++++ .../pipelines/versatile_diffusion/modeling_text_unet.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index cb158af3ac9c..51c66ca59a68 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -673,6 +673,8 @@ def forward( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + # parameter exists only for interface-compatibility with other blocks. prefer attention_mask + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} hidden_states = self.resnets[0](hidden_states, temb) @@ -1524,6 +1526,8 @@ def forward( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + # parameter exists only for interface-compatibility with other blocks. prefer attention_mask + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} @@ -2623,6 +2627,8 @@ def forward( upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + # parameter exists only for interface-compatibility with other blocks. prefer attention_mask + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} for resnet, attn in zip(self.resnets, self.attentions): diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index dceaacc270d5..303c11da1ee3 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1707,6 +1707,8 @@ def forward( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + # parameter exists only for interface-compatibility with other blocks. prefer attention_mask + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} hidden_states = self.resnets[0](hidden_states, temb) From 937ab7020123142e4e156945927779cda469932a Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 20 May 2023 11:28:55 +0100 Subject: [PATCH 23/25] restore passing of emb to KAttentionBlock#forward, on the basis that removal caused test failures. restore also the passing of emb to checkpointed calls to KAttentionBlock#forward. 
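The checkpointed call sites in this series pass every argument positionally, with None placeholders keeping the positions aligned, because torch.utils.checkpoint.checkpoint forwards positional inputs while return_dict is supplied through create_custom_forward(attn, return_dict=False). A self-contained sketch of that pattern (ToyAttnBlock is hypothetical; its parameter order only mimics the blocks touched here and is not the library's class):

    from typing import Any, Dict, Optional

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint


    class ToyAttnBlock(nn.Module):
        # hypothetical stand-in: only the stable positional parameter order matters here
        def __init__(self, dim: int = 8):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(
            self,
            hidden_states: torch.Tensor,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            emb: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.Tensor] = None,  # new params appended at the end
            return_dict: bool = True,
        ):
            out = self.proj(hidden_states)
            return {"sample": out} if return_dict else (out,)


    def create_custom_forward(module, return_dict=None):
        # return_dict is injected via the closure; checkpoint() forwards only positional args
        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            return module(*inputs)

        return custom_forward


    block = ToyAttnBlock()
    hidden_states = torch.randn(2, 4, 8, requires_grad=True)
    emb = torch.randn(2, 8)

    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}  # gated on torch >= 1.11 in the diff
    hidden_states = checkpoint(
        create_custom_forward(block, return_dict=False),
        hidden_states,
        None,  # encoder_hidden_states
        emb,   # threaded through positionally rather than replaced with None
        None,  # attention_mask
        None,  # cross_attention_kwargs
        None,  # encoder_attention_mask
        **ckpt_kwargs,
    )[0]
    hidden_states.sum().backward()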
--- src/diffusers/models/unet_2d_blocks.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 51c66ca59a68..9d753447d44f 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1750,7 +1750,7 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, - None, # emb + temb, attention_mask, cross_attention_kwargs, encoder_attention_mask, @@ -1761,6 +1761,7 @@ def custom_forward(*inputs): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, + emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, @@ -2876,7 +2877,7 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, - None, # temb + temb, attention_mask, cross_attention_kwargs, encoder_attention_mask, @@ -2887,6 +2888,7 @@ def custom_forward(*inputs): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, + emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, @@ -2970,6 +2972,8 @@ def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, + # TODO: mark emb as non-optional (self.norm2 requires it). + # requires assessing impact of change to positional param interface. emb: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, From 83762b27087a753a832e1c0105f82ffb69e31475 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 20 May 2023 12:22:15 +0100 Subject: [PATCH 24/25] make simple unet2d blocks use encoder_attention_mask, but only when attention_mask is None. this should fix UnCLIP compatibility. --- src/diffusers/models/unet_2d_blocks.py | 48 +++++++++++++++++++++----- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 9d753447d44f..6f8e3d0f5500 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -673,17 +673,28 @@ def forward( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - # parameter exists only for interface-compatibility with other blocks. prefer attention_mask encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. 
+ # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, + attention_mask=mask, **cross_attention_kwargs, ) @@ -1526,12 +1537,22 @@ def forward( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - # parameter exists only for interface-compatibility with other blocks. prefer attention_mask encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: @@ -1549,7 +1570,7 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, - attention_mask, + mask, cross_attention_kwargs, )[0] else: @@ -1558,7 +1579,7 @@ def custom_forward(*inputs): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, + attention_mask=mask, **cross_attention_kwargs, ) @@ -2628,10 +2649,21 @@ def forward( upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - # parameter exists only for interface-compatibility with other blocks. prefer attention_mask encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. 
+ # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + for resnet, attn in zip(self.resnets, self.attentions): # resnet # pop res hidden states @@ -2655,7 +2687,7 @@ def custom_forward(*inputs): create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, - attention_mask, + mask, cross_attention_kwargs, )[0] else: @@ -2664,7 +2696,7 @@ def custom_forward(*inputs): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, + attention_mask=mask, **cross_attention_kwargs, ) From 8c323c5fc84e8874ebbdff0a5afee3794471039c Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 20 May 2023 12:22:23 +0100 Subject: [PATCH 25/25] fix copies --- .../versatile_diffusion/modeling_text_unet.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 303c11da1ee3..29cde43337d2 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1707,17 +1707,28 @@ def forward( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - # parameter exists only for interface-compatibility with other blocks. prefer attention_mask encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, + attention_mask=mask, **cross_attention_kwargs, )
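Summing up the last two patches: the mask-selection logic that the simple cross-attn blocks now inline can be read as a small helper like the hypothetical one below (the blocks repeat the if/else inline rather than calling a shared function):

    from typing import Optional

    import torch


    def select_simple_block_mask(
        attention_mask: Optional[torch.Tensor],
        encoder_attention_mask: Optional[torch.Tensor],
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        # hypothetical helper distilling the if/else block added above
        if attention_mask is None:
            # cross-attending whenever encoder_hidden_states is given, so prefer the cross-attn mask
            return None if encoder_hidden_states is None else encoder_attention_mask
        # UnCLIP routes its cross-attn mask through `attention_mask`, so honour it unchanged
        return attention_mask

With this behaviour, an UnCLIP-style caller that still passes its cross-attn mask as attention_mask keeps its old behaviour, while a caller that only provides encoder_attention_mask gets it applied whenever encoder_hidden_states is present.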