
Commit 4232ad0

Cross-attention masks
1 parent 3be4891 commit 4232ad0

7 files changed, +157 -59 lines changed


Diff for: src/diffusers/models/attention.py

+9 -11

@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Callable, Optional
+from typing import Any, Callable, Dict, Optional
 
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import FloatTensor, LongTensor, nn
 
 from ..utils.import_utils import is_xformers_available
 from .attention_processor import Attention

@@ -275,13 +275,13 @@ def __init__(
 
     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        timestep=None,
-        cross_attention_kwargs=None,
-        class_labels=None,
+        hidden_states: Optional[FloatTensor],
+        attention_mask: Optional[FloatTensor] = None,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        encoder_attention_mask: Optional[FloatTensor] = None,
+        timestep: Optional[LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        class_labels: Optional[LongTensor] = None,
     ):
         if self.use_ada_layer_norm:
             norm_hidden_states = self.norm1(hidden_states, timestep)

@@ -308,8 +308,6 @@ def forward(
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
-            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
-            # prepare attention mask here
 
             # 2. Cross-Attention
             attn_output = self.attn2(

Diff for: src/diffusers/models/attention_processor.py

+23 -6

@@ -15,7 +15,7 @@
 
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import FloatTensor, nn
 
 from ..utils import deprecate, logging
 from ..utils.import_utils import is_xformers_available

@@ -277,15 +277,22 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
         if attention_mask is None:
             return attention_mask
 
-        if attention_mask.shape[-1] != target_length:
+        current_length: int = attention_mask.shape[-1]
+        if current_length > target_length:
+            # we *could* trim the mask with:
+            #   attention_mask = attention_mask[:,:target_length]
+            # but this is weird enough that it's more likely to be a mistake than a shortcut
+            raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
+        elif current_length < target_length:
             if attention_mask.device.type == "mps":
                 # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                 # Instead, we can manually construct the padding tensor.
                 padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                 padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                 attention_mask = torch.cat([attention_mask, padding], dim=2)
             else:
-                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                remaining_length: int = target_length - current_length
+                attention_mask = F.pad(attention_mask, (0, remaining_length), value=0.0)
 
         if attention_mask.shape[0] < batch_size * head_size:
             attention_mask = attention_mask.repeat_interleave(head_size, dim=0)

@@ -441,12 +448,22 @@ class XFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op
 
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = (
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+    ):
+        batch_size, key_tokens, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
 
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
+        if attention_mask is not None:
+            # xformers doesn't broadcast for us, so we expand our singleton dimension manually
+            _, query_tokens, _ = hidden_states.shape
+            attention_mask = attention_mask.expand(-1, query_tokens, -1)
 
         query = attn.to_q(hidden_states)
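
To make the mask handling above concrete, here is a minimal standalone sketch (not the library code) of what `prepare_attention_mask` plus the xformers processor now do to a cross-attention bias: pad the key dimension with zeros up to the key length, repeat the mask per attention head, then expand the singleton query dimension, since xformers will not broadcast it. All shapes and variable names are illustrative.

```python
import torch
import torch.nn.functional as F

batch, heads, query_tokens, key_tokens = 2, 8, 4096, 77

# the bias arrives as (batch, 1, current_length); here the prompt mask is shorter than key_tokens
attention_mask = torch.zeros(batch, 1, 64)

current_length = attention_mask.shape[-1]
if current_length > key_tokens:
    raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({key_tokens}).")
elif current_length < key_tokens:
    # pad the key dimension with zeros, i.e. the padded positions stay attendable
    attention_mask = F.pad(attention_mask, (0, key_tokens - current_length), value=0.0)

# repeat the mask per attention head, as prepare_attention_mask does when batch < batch * heads
attention_mask = attention_mask.repeat_interleave(heads, dim=0)

# xformers doesn't broadcast the singleton query dimension, so expand it explicitly
attention_mask = attention_mask.expand(-1, query_tokens, -1)
print(attention_mask.shape)  # torch.Size([16, 4096, 77])
```

With a 64-token mask and 77 key tokens, the padded positions receive a bias of 0.0 and therefore remain attendable, matching `F.pad(..., value=0.0)` in the diff.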

Diff for: src/diffusers/models/embeddings.py

+2 -2

@@ -16,7 +16,7 @@
 
 import numpy as np
 import torch
-from torch import nn
+from torch import LongTensor, nn
 
 
 def get_timestep_embedding(

@@ -352,7 +352,7 @@ def token_drop(self, labels, force_drop_ids=None):
         labels = torch.where(drop_ids, self.num_classes, labels)
         return labels
 
-    def forward(self, labels, force_drop_ids=None):
+    def forward(self, labels: LongTensor, force_drop_ids=None):
         use_dropout = self.dropout_prob > 0
         if (self.training and use_dropout) or (force_drop_ids is not None):
             labels = self.token_drop(labels, force_drop_ids)

Diff for: src/diffusers/models/transformer_2d.py

+16 -8

@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Dict, Optional
 
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import LongTensor, Tensor, nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..models.embeddings import ImagePositionalEmbeddings

@@ -213,22 +213,28 @@ def __init__(
 
     def forward(
         self,
-        hidden_states,
-        encoder_hidden_states=None,
-        timestep=None,
-        class_labels=None,
-        cross_attention_kwargs=None,
+        hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        encoder_hidden_states: Optional[Tensor] = None,
+        encoder_attention_mask: Optional[Tensor] = None,
+        timestep: Optional[LongTensor] = None,
+        class_labels: Optional[LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
         return_dict: bool = True,
     ):
         """
         Args:
             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                 When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                 hidden_states
+            attention_mask ( `torch.Tensor` of shape (batch size, num latent pixels), *optional* ).
+                Bias to add to attention scores.
             encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
-            timestep ( `torch.long`, *optional*):
+            encoder_attention_mask ( `torch.Tensor` of shape (batch size, num encoder tokens), *optional* ).
+                Bias to add to cross-attention scores.
+            timestep ( `torch.LongTensor`, *optional*):
                 Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
             class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                 Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels

@@ -264,7 +270,9 @@ def forward(
         for block in self.transformer_blocks:
             hidden_states = block(
                 hidden_states,
+                attention_mask=attention_mask,
                 encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
                 timestep=timestep,
                 cross_attention_kwargs=cross_attention_kwargs,
                 class_labels=class_labels,
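
As a reference for what "Bias to add to attention scores" means in the docstring above, here is a generic scaled-dot-product attention sketch (not the `Transformer2DModel` code): the bias is added to the pre-softmax scores, so positions carrying a large negative value end up with near-zero attention probability. All shapes are illustrative.

```python
import math

import torch

batch, query_tokens, key_tokens, dim = 1, 4, 6, 32
q = torch.randn(batch, query_tokens, dim)
k = torch.randn(batch, key_tokens, dim)
v = torch.randn(batch, key_tokens, dim)

# bias: 0.0 where a key token may be attended, a large negative value where it is hidden
bias = torch.zeros(batch, 1, key_tokens)
bias[..., 4:] = -10000.0  # hide the last two key tokens

scores = q @ k.transpose(-1, -2) / math.sqrt(dim) + bias  # bias is added before the softmax
probs = scores.softmax(dim=-1)  # hidden key tokens receive ~0 probability
out = probs @ v
print(probs[0, 0])  # the last two entries are effectively zero
```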

Diff for: src/diffusers/models/unet_2d_blocks.py

+39 -17

@@ -11,17 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Any, Dict, Optional, Tuple
 
 import numpy as np
 import torch
-from torch import nn
+from torch import FloatTensor, nn
 
 from .attention import AdaGroupNorm, AttentionBlock
 from .attention_processor import Attention, AttnAddedKVProcessor
 from .dual_transformer_2d import DualTransformer2DModel
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
-from .transformer_2d import Transformer2DModel
+from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
 
 
 def get_down_block(

@@ -533,15 +533,24 @@ def __init__(
         self.resnets = nn.ModuleList(resnets)
 
     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
-    ):
+        self,
+        hidden_states: FloatTensor,
+        temb: Optional[FloatTensor] = None,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        encoder_attention_mask: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(
+            output: Transformer2DModelOutput = attn(
                 hidden_states,
+                attention_mask=attention_mask,
                 encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+            )
+            hidden_states = output.sample
             hidden_states = resnet(hidden_states, temb)
 
         return hidden_states

@@ -808,9 +817,14 @@ def __init__(
         self.gradient_checkpointing = False
 
     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+        self,
+        hidden_states: FloatTensor,
+        temb: Optional[FloatTensor] = None,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        encoder_attention_mask: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        # TODO(Patrick, William) - attention mask is not used
         output_states = ()
 
         for resnet, attn in zip(self.resnets, self.attentions):

@@ -829,14 +843,18 @@ def custom_forward(*inputs):
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(attn, return_dict=False),
                     hidden_states,
+                    attention_mask,
                     encoder_hidden_states,
+                    encoder_attention_mask,
                     cross_attention_kwargs,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
+                    attention_mask=attention_mask,
                     encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
                 ).sample
 

@@ -1775,15 +1793,15 @@ def __init__(
 
     def forward(
         self,
-        hidden_states,
-        res_hidden_states_tuple,
-        temb=None,
-        encoder_hidden_states=None,
-        cross_attention_kwargs=None,
-        upsample_size=None,
-        attention_mask=None,
+        hidden_states: FloatTensor,
+        res_hidden_states_tuple: Tuple[FloatTensor, ...],
+        temb: Optional[FloatTensor] = None,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        encoder_attention_mask: Optional[FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[FloatTensor] = None,
     ):
-        # TODO(Patrick, William) - attention mask is not used
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]

@@ -1805,14 +1823,18 @@ def custom_forward(*inputs):
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(attn, return_dict=False),
                     hidden_states,
+                    attention_mask,
                     encoder_hidden_states,
+                    encoder_attention_mask,
                     cross_attention_kwargs,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
+                    attention_mask=attention_mask,
                     encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
                 ).sample

Diff for: src/diffusers/models/unet_2d_condition.py

+15

@@ -522,6 +522,7 @@ def forward(
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,

@@ -535,6 +536,10 @@ def forward(
             sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
             timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
             encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            encoder_attention_mask (`torch.Tensor`):
+                (batch, sequence_length) cross-attention mask (or bias), applied to encoder_hidden_states. if a
+                BoolTensor is provided: will be turned into a bias, by adding a large negative value. False = hide
+                token. other tensor types will be used as a bias as-is.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
             cross_attention_kwargs (`dict`, *optional*):

@@ -566,6 +571,13 @@ def forward(
             attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
             attention_mask = attention_mask.unsqueeze(1)
 
+        # ensure encoder_attention_mask is a bias, and make it broadcastable over multi-head-attention channels
+        if encoder_attention_mask is not None:
+            # if it's a mask: turn it into a bias. otherwise: assume it's already a bias
+            if encoder_attention_mask.dtype is torch.bool:
+                encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0

@@ -621,6 +633,7 @@ def forward(
                     hidden_states=sample,
                     temb=emb,
                     encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
                 )

@@ -646,6 +659,7 @@ def forward(
             sample,
             emb,
             encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
             attention_mask=attention_mask,
             cross_attention_kwargs=cross_attention_kwargs,
         )

@@ -671,6 +685,7 @@ def forward(
                     temb=emb,
                     res_hidden_states_tuple=res_samples,
                     encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
                     upsample_size=upsample_size,
                     attention_mask=attention_mask,
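
A small sketch of how a caller might exercise the new `encoder_attention_mask` argument, mirroring the bool-to-bias conversion that `UNet2DConditionModel.forward` performs above. The shapes and the 60-token prompt length are made up for illustration.

```python
import torch

batch, text_tokens = 2, 77
dtype = torch.float16

# keep/hide mask, e.g. cast from a tokenizer's attention_mask: True = attend, False = hide token
encoder_attention_mask = torch.ones(batch, text_tokens, dtype=torch.bool)
encoder_attention_mask[:, 60:] = False  # pretend the prompt used only 60 of the 77 token slots

# same conversion as in the diff: False -> large negative bias, True -> 0.0
if encoder_attention_mask.dtype is torch.bool:
    encoder_attention_mask = (1 - encoder_attention_mask.to(dtype)) * -10000.0
# unsqueeze so the bias broadcasts over attention heads
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

print(encoder_attention_mask.shape)  # torch.Size([2, 1, 77])
```

Per the new docstring, a BoolTensor is converted this way inside the forward pass, so a tokenizer's mask (1 = real token, 0 = padding) can simply be cast to bool and passed in; any other tensor type is used as a bias as-is.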
