
Commit 261ccf6

Support for cross-attention bias / mask (huggingface#2634)
* Cross-attention masks
  - prefer qualified symbol, fix accidental Optional
  - prefer qualified symbol in AttentionProcessor
  - prefer qualified symbol in embeddings.py
  - qualified symbol in transformer_2d
  - qualify FloatTensor in unet_2d_blocks
  - move new transformer_2d params attention_mask, encoder_attention_mask to the end of the section which is assumed (e.g. by functions such as checkpoint()) to have a stable positional param interface. regard return_dict as a special-case which is assumed to be injected separately from positional params (e.g. by create_custom_forward()).
  - move new encoder_attention_mask param to end of CrossAttn block interfaces and Unet2DCondition interface, to maintain positional param interface.
  - regenerate modeling_text_unet.py
  - remove unused import
  - unet_2d_condition encoder_attention_mask docs (Co-authored-by: Pedro Cuenca <[email protected]>)
  - versatile_diffusion/modeling_text_unet.py encoder_attention_mask docs (Co-authored-by: Pedro Cuenca <[email protected]>)
  - transformer_2d encoder_attention_mask docs (Co-authored-by: Pedro Cuenca <[email protected]>)
  - unet_2d_blocks.py: add parameter name comments (Co-authored-by: Pedro Cuenca <[email protected]>)
  - revert description. bool-to-bias treatment happens in unet_2d_condition only.
  - comment parameter names
  - fix copies, style
* encoder_attention_mask for SimpleCrossAttnDownBlock2D, SimpleCrossAttnUpBlock2D
* encoder_attention_mask for UNetMidBlock2DSimpleCrossAttn
* support attention_mask, encoder_attention_mask in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D, KAttentionBlock. fix binding of attention_mask, cross_attention_kwargs params in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D checkpoint invocations.
* fix mistake made during merge conflict resolution
* regenerate versatile_diffusion
* pass time embedding into checkpointed attention invocation
* always assume encoder_attention_mask is a mask (i.e. not a bias).
* style, fix-copies
* add tests for cross-attention masks
* add test for padding of attention mask
* explain mask's query_tokens dim. fix explanation about broadcasting over channels; we actually broadcast over query tokens
* support both masks and biases in Transformer2DModel#forward. document behaviour
* fix-copies
* delete attention_mask docs on the basis I never tested self-attention masking myself. not comfortable explaining it, since I don't actually understand how a self-attn mask can work in its current form: the key length will be different in every ResBlock (we don't downsample the mask when we downsample the image).
* review feedback: the standard Unet blocks shouldn't pass temb to attn (only to resnet). remove from KCrossAttnDownBlock2D, KCrossAttnUpBlock2D#forward.
* remove encoder_attention_mask param from SimpleCrossAttn{Up,Down}Block2D, UNetMidBlock2DSimpleCrossAttn, and mask-choice in those blocks' #forward, on the basis that they only do one type of attention, so the consumer can pass whichever type of attention_mask is appropriate.
* put attention mask padding back to how it was (since the SD use-case it enabled wasn't important, and it breaks the original unclip use-case). disable the test which was added.
* fix-copies
* style
* fix-copies
* put encoder_attention_mask param back into Simple block forward interfaces, to ensure consistency of forward interface.
* restore passing of emb to KAttentionBlock#forward, on the basis that removal caused test failures. restore also the passing of emb to checkpointed calls to KAttentionBlock#forward.
* make simple unet2d blocks use encoder_attention_mask, but only when attention_mask is None. this should fix UnCLIP compatibility.
* fix copies
1 parent 4be5115 commit 261ccf6

File tree

7 files changed: +402 -206 lines

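Before the per-file diffs, a minimal usage sketch of the feature described above. It assumes (per the commit message) that `UNet2DConditionModel#forward` now accepts an `encoder_attention_mask` in the boolean `(batch, sequence_length)` keep-mask format documented in the `transformer_2d.py` diff below; the tiny model config and the token counts are made up for illustration (mirroring diffusers' unit-test shapes) and are not part of this commit.

```python
import torch
from diffusers.models import UNet2DConditionModel

# Hypothetical tiny config, for illustration only.
unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    layers_per_block=2,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)

batch, text_len = 2, 77
sample = torch.randn(batch, 4, 32, 32)
timestep = torch.tensor([10, 10])
encoder_hidden_states = torch.randn(batch, text_len, 32)

# Keep-mask over the text tokens: True = attend, False = ignore (e.g. padding).
# A 2-D mask is converted to an additive bias internally, per the docs added below.
encoder_attention_mask = torch.zeros(batch, text_len, dtype=torch.bool)
encoder_attention_mask[0, :7] = True   # first prompt: 7 real tokens (made-up count)
encoder_attention_mask[1, :12] = True  # second prompt: 12 real tokens (made-up count)

out = unet(
    sample,
    timestep,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
).sample
print(out.shape)  # torch.Size([2, 4, 32, 32])
```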

Diff for: models/attention.py (+8 -10)
```diff
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Any, Dict, Optional
 
 import torch
 import torch.nn.functional as F
@@ -120,13 +120,13 @@ def __init__(
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        timestep=None,
-        cross_attention_kwargs=None,
-        class_labels=None,
+        hidden_states: torch.FloatTensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        class_labels: Optional[torch.LongTensor] = None,
     ):
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 1. Self-Attention
@@ -155,8 +155,6 @@ def forward(
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
-            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
-            # prepare attention mask here
 
             attn_output = self.attn2(
                 norm_hidden_states,
```
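The `BasicTransformerBlock.forward` signature above now threads `attention_mask` and `encoder_attention_mask` (as typed, optional tensors) down to the attention layers, where a prepared mask acts as an additive bias on the attention scores. A standalone sketch in plain torch, not library code, of what such a bias does to the softmax:

```python
import math
import torch

batch_heads, query_tokens, key_tokens, head_dim = 2, 4, 6, 8
q = torch.randn(batch_heads, query_tokens, head_dim)
k = torch.randn(batch_heads, key_tokens, head_dim)

# Additive bias in the format the new docstrings describe: 0 = keep, -10000 = discard.
bias = torch.zeros(batch_heads, 1, key_tokens)
bias[:, :, 4:] = -10000.0  # discard the last two keys

# The singleton query_tokens dimension broadcasts over all queries.
scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim) + bias
probs = scores.softmax(dim=-1)
print(probs[0, 0])  # attention to the masked keys is ~0
```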

Diff for: models/attention_processor.py (+29 -4)
```diff
@@ -380,14 +380,24 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None,
         if attention_mask is None:
             return attention_mask
 
-        if attention_mask.shape[-1] != target_length:
+        current_length: int = attention_mask.shape[-1]
+        if current_length > target_length:
+            # we *could* trim the mask with:
+            #   attention_mask = attention_mask[:,:target_length]
+            # but this is weird enough that it's more likely to be a mistake than a shortcut
+            raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
+        elif current_length < target_length:
             if attention_mask.device.type == "mps":
                 # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                 # Instead, we can manually construct the padding tensor.
                 padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                 padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                 attention_mask = torch.cat([attention_mask, padding], dim=2)
             else:
+                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+                #   we want to instead pad by (0, remaining_length), where remaining_length is:
+                #   remaining_length: int = target_length - current_length
+                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
 
         if out_dim == 3:
```
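For reference, a standalone sketch (made-up shapes, not the library call) of the padding this hunk discusses: padding the mask's last dimension with zeros via `torch.cat`, as the mps branch does manually, matches a right-hand `F.pad`. Note that both in-tree branches currently append `target_length` extra columns (which the TODO above flags); the sketch appends `remaining_length` so the result has exactly `target_length` keys.

```python
import torch
import torch.nn.functional as F

batch, heads, current_length, target_length = 2, 8, 5, 12
attention_mask = torch.randn(batch, heads, current_length)

remaining_length = target_length - current_length

# F.pad pads the last dim on the right with `remaining_length` zeros...
padded_with_pad = F.pad(attention_mask, (0, remaining_length), value=0.0)

# ...which is exactly what manually concatenating a zero block along dim=2 produces.
padding = torch.zeros(batch, heads, remaining_length, dtype=attention_mask.dtype)
padded_with_cat = torch.cat([attention_mask, padding], dim=2)

assert torch.equal(padded_with_pad, padded_with_cat)
print(padded_with_cat.shape)  # torch.Size([2, 8, 12])
```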
```diff
@@ -820,7 +830,13 @@ class XFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
 
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+    ):
         residual = hidden_states
 
         input_ndim = hidden_states.ndim
@@ -829,11 +845,20 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
             batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
 
-        batch_size, sequence_length, _ = (
+        batch_size, key_tokens, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
 
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
+        if attention_mask is not None:
+            # expand our mask's singleton query_tokens dimension:
+            #   [batch*heads,            1, key_tokens] ->
+            #   [batch*heads, query_tokens, key_tokens]
+            # so that it can be added as a bias onto the attention scores that xformers computes:
+            #   [batch*heads, query_tokens, key_tokens]
+            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+            _, query_tokens, _ = hidden_states.shape
+            attention_mask = attention_mask.expand(-1, query_tokens, -1)
 
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
```
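A quick standalone illustration (made-up shapes) of the `expand` added above: it gives the bias the explicit `[batch*heads, query_tokens, key_tokens]` shape that xformers expects, without copying memory, since xformers does not broadcast the singleton query_tokens dimension on its own.

```python
import torch

batch_heads, query_tokens, key_tokens = 16, 4096, 77

# bias as produced by prepare_attention_mask: one row per (batch * head), singleton query dim
attention_mask = torch.zeros(batch_heads, 1, key_tokens)
attention_mask[:, :, 64:] = -10000.0  # discard trailing (padding) text tokens

expanded = attention_mask.expand(-1, query_tokens, -1)
print(expanded.shape)                                    # torch.Size([16, 4096, 77])
print(expanded.data_ptr() == attention_mask.data_ptr())  # True: a view, no extra memory
```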

Diff for: models/embeddings.py (+1 -1)
```diff
@@ -352,7 +352,7 @@ def token_drop(self, labels, force_drop_ids=None):
         labels = torch.where(drop_ids, self.num_classes, labels)
         return labels
 
-    def forward(self, labels, force_drop_ids=None):
+    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
         use_dropout = self.dropout_prob > 0
         if (self.training and use_dropout) or (force_drop_ids is not None):
             labels = self.token_drop(labels, force_drop_ids)
```

Diff for: models/transformer_2d.py (+40 -7)
```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Dict, Optional
 
 import torch
 import torch.nn.functional as F
@@ -213,11 +213,13 @@ def __init__(
 
     def forward(
         self,
-        hidden_states,
-        encoder_hidden_states=None,
-        timestep=None,
-        class_labels=None,
-        cross_attention_kwargs=None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
         """
@@ -228,11 +230,17 @@ def forward(
             encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
-            timestep ( `torch.long`, *optional*):
+            timestep ( `torch.LongTensor`, *optional*):
                 Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
             class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                 Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
                 conditioning.
+            encoder_attention_mask ( `torch.Tensor`, *optional* ).
+                Cross-attention mask, applied to encoder_hidden_states. Two formats supported:
+                    Mask `(batch, sequence_length)` True = keep, False = discard. Bias `(batch, 1, sequence_length)` 0
+                    = keep, -10000 = discard.
+                If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format
+                above. This bias will be added to the cross-attention scores.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
 
@@ -241,6 +249,29 @@ def forward(
             [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is the sample tensor.
         """
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
         # 1. Input
         if self.is_input_continuous:
             batch, _, height, width = hidden_states.shape
@@ -264,7 +295,9 @@ def forward(
         for block in self.transformer_blocks:
             hidden_states = block(
                 hidden_states,
+                attention_mask=attention_mask,
                 encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
                 timestep=timestep,
                 cross_attention_kwargs=cross_attention_kwargs,
                 class_labels=class_labels,
```
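A self-contained walk-through in plain torch of the mask-to-bias conversion that `Transformer2DModel#forward` now performs above: a 2-D keep-mask becomes an additive bias with a singleton query_tokens dimension, and masked keys receive ~zero attention probability once the scores go through softmax.

```python
import torch

batch, key_tokens = 2, 6
keep = torch.tensor([[1, 1, 1, 0, 0, 0],
                     [1, 1, 1, 1, 1, 0]])  # 1 = keep, 0 = discard

dtype = torch.float32
bias = (1 - keep.to(dtype)) * -10000.0   # keep -> 0.0, discard -> -10000.0
bias = bias.unsqueeze(1)                 # [batch, key_tokens] -> [batch, 1, key_tokens]
print(bias.shape)                        # torch.Size([2, 1, 6])

# broadcasting the bias over attention scores of shape [batch, heads, query_tokens, key_tokens]
heads, query_tokens = 4, 3
scores = torch.randn(batch, heads, query_tokens, key_tokens) + bias.unsqueeze(1)
probs = scores.softmax(dim=-1)
print(probs[0, 0, 0])  # the three masked keys of the first sample get ~0 probability
```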
