
Commit cbb4c02

implement cross-attention mask

1 parent fc94c60 · commit cbb4c02

File tree: 3 files changed, +41 −18 lines

Diff for: src/diffusers/models/attention.py

+30 −12
@@ -17,7 +17,8 @@
 
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import nn, Tensor
+from einops import rearrange, repeat
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
@@ -175,7 +176,7 @@ def __init__(
         self.norm_out = nn.LayerNorm(inner_dim)
         self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
 
-    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True, cross_attn_mask: Optional[torch.Tensor] = None):
         """
         Args:
             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
@@ -213,7 +214,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu
 
         # 2. Blocks
         for block in self.transformer_blocks:
-            hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
+            hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep, cross_attn_mask=cross_attn_mask)
 
         # 3. Output
         if self.is_input_continuous:
@@ -472,14 +473,14 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten
         self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
         self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
 
-    def forward(self, hidden_states, context=None, timestep=None):
+    def forward(self, hidden_states, context=None, timestep=None, cross_attn_mask: Optional[torch.Tensor] = None):
         # 1. Self-Attention
         norm_hidden_states = (
             self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
         )
 
         if self.only_cross_attention:
-            hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
+            hidden_states = self.attn1(norm_hidden_states, context, cross_attn_mask=cross_attn_mask) + hidden_states
         else:
             hidden_states = self.attn1(norm_hidden_states) + hidden_states
 
@@ -488,7 +489,7 @@ def forward(self, hidden_states, context=None, timestep=None):
         norm_hidden_states = (
             self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
         )
-        hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
+        hidden_states = self.attn2(norm_hidden_states, context=context, cross_attn_mask=cross_attn_mask) + hidden_states
 
         # 3. Feed-forward
         hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -563,7 +564,7 @@ def set_attention_slice(self, slice_size):
 
         self._slice_size = slice_size
 
-    def forward(self, hidden_states, context=None, mask=None):
+    def forward(self, hidden_states, context=None, mask=None, cross_attn_mask:Optional[Tensor]=None):
         batch_size, sequence_length, _ = hidden_states.shape
 
         query = self.to_q(hidden_states)
@@ -577,26 +578,29 @@ def forward(self, hidden_states, context=None, mask=None):
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
 
-        # TODO(PVP) - mask is currently never used. Remember to re-implement when used
+        # TODO AKB: `mask` param remains unimplemented. the parameter remains reserved
+        # in case we should ever need a self-attention mask
+        # (e.g. a pixel/latent-space mask to avoid self-attending to padding pixels, such as pillarboxing/letterboxing).
 
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
+            assert cross_attn_mask is None, "cross-attention masking not implemented for xformers attention"
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query.dtype)
         else:
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
-                hidden_states = self._attention(query, key, value)
+                hidden_states = self._attention(query, key, value, cross_attn_mask=cross_attn_mask)
             else:
-                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, cross_attn_mask=cross_attn_mask)
 
         # linear proj
         hidden_states = self.to_out[0](hidden_states)
         # dropout
         hidden_states = self.to_out[1](hidden_states)
         return hidden_states
 
-    def _attention(self, query, key, value):
+    def _attention(self, query, key, value, cross_attn_mask:Optional[Tensor]=None):
         if self.upcast_attention:
             query = query.float()
             key = key.float()
@@ -608,6 +612,12 @@ def _attention(self, query, key, value):
             beta=0,
             alpha=self.scale,
         )
+        if cross_attn_mask is not None:
+            cross_attn_mask = rearrange(cross_attn_mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(attention_scores.dtype).max
+            cross_attn_mask = repeat(cross_attn_mask, 'b j -> (b h) () j', h=self.heads)
+            attention_scores.masked_fill_(~cross_attn_mask, max_neg_value)
+            del cross_attn_mask
         attention_probs = attention_scores.softmax(dim=-1)
 
         # cast back to the original dtype
@@ -620,11 +630,15 @@ def _attention(self, query, key, value):
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
-    def _sliced_attention(self, query, key, value, sequence_length, dim):
+    def _sliced_attention(self, query, key, value, sequence_length, dim, cross_attn_mask:Optional[Tensor]=None):
         batch_size_attention = query.shape[0]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
         )
+        if cross_attn_mask is not None:
+            cross_attn_mask = rearrange(cross_attn_mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(query.dtype).max
+            cross_attn_mask = repeat(cross_attn_mask, 'b j -> (b h) () j', h=self.heads)
         slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
@@ -644,6 +658,10 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
                 beta=0,
                 alpha=self.scale,
             )
+            if cross_attn_mask is not None:
+                cross_attn_mask_slice = cross_attn_mask[start_idx:end_idx]
+                attn_slice.masked_fill_(~cross_attn_mask_slice, max_neg_value)
+                del cross_attn_mask_slice
             attn_slice = attn_slice.softmax(dim=-1)
 
             # cast back to the original dtype
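
How the new masking code reads: `cross_attn_mask` is a boolean tensor of shape (batch, encoder_tokens) — any extra trailing dims are flattened by the `rearrange` — in which False marks key positions (e.g. padding text tokens) that no query may attend to. It is tiled across heads, broadcast over query positions via the unit axis inserted by `repeat`, and masked scores are pushed to the most negative finite value before the softmax, so those keys receive (numerically) zero probability; the sliced path applies the same mask per (batch × heads) slice. Below is a minimal, self-contained sketch of that pattern — every shape and name in it is illustrative, not taken from the commit.

import torch
from einops import rearrange, repeat

# Illustrative sizes: 2 images, 8 heads, 64x64 latent (4096 queries), 77 text tokens, 40 dims per head.
batch, heads, q_len, k_len, dim_head = 2, 8, 4096, 77, 40
scale = dim_head ** -0.5

query = torch.randn(batch * heads, q_len, dim_head)  # as after reshape_heads_to_batch_dim
key = torch.randn(batch * heads, k_len, dim_head)

# Boolean mask: True = real text token, False = padding to be ignored.
cross_attn_mask = torch.zeros(batch, k_len, dtype=torch.bool)
cross_attn_mask[:, :20] = True  # pretend each prompt is 20 tokens long

attention_scores = torch.baddbmm(
    torch.empty(query.shape[0], q_len, k_len, dtype=query.dtype),
    query,
    key.transpose(-1, -2),
    beta=0,  # the input matrix is ignored, only query @ key^T remains
    alpha=scale,
)

mask = rearrange(cross_attn_mask, 'b ... -> b (...)')  # flatten any trailing dims
mask = repeat(mask, 'b j -> (b h) () j', h=heads)      # (batch*heads, 1, k_len): broadcasts over queries
attention_scores.masked_fill_(~mask, -torch.finfo(attention_scores.dtype).max)

attention_probs = attention_scores.softmax(dim=-1)
assert attention_probs[..., 20:].max() < 1e-6  # padding tokens receive ~zero attention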

Diff for: src/diffusers/models/unet_2d_blocks.py

+7 −5
@@ -14,6 +14,7 @@
 import numpy as np
 import torch
 from torch import nn
+from typing import Optional
 
 from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
@@ -408,10 +409,10 @@ def __init__(
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, cross_attn_mask: Optional[torch.Tensor] = None):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(hidden_states, encoder_hidden_states).sample
+            hidden_states = attn(hidden_states, encoder_hidden_states, cross_attn_mask=cross_attn_mask).sample
             hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
@@ -588,7 +589,7 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, cross_attn_mask: Optional[torch.Tensor] = None):
         output_states = ()
 
         for resnet, attn in zip(self.resnets, self.attentions):
@@ -609,7 +610,7 @@ def custom_forward(*inputs):
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
-                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attn_mask=cross_attn_mask).sample
 
             output_states += (hidden_states,)
 
@@ -1176,6 +1177,7 @@ def forward(
         temb=None,
         encoder_hidden_states=None,
         upsample_size=None,
+        cross_attn_mask: Optional[torch.Tensor] = None,
     ):
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
@@ -1200,7 +1202,7 @@ def custom_forward(*inputs):
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
-                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attn_mask=cross_attn_mask).sample
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
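
A caveat visible in the two `custom_forward` hunks above: only the non-checkpointed `else:` branches pass `cross_attn_mask` on; the `torch.utils.checkpoint` calls (the lines ending in `)[0]`) are unchanged, so the mask is dropped whenever `gradient_checkpointing` is enabled. One way to cover that path is to capture the mask in the checkpointed closure, since `checkpoint()` only re-feeds positional tensor inputs during recomputation. The toy, self-contained sketch below illustrates that pattern — `ToyCrossAttnBlock` and this `create_custom_forward` are stand-ins, not the library's code.

import torch
from torch.utils import checkpoint


class ToyCrossAttnBlock(torch.nn.Module):
    # Stand-in for a cross-attention block: hidden states, context, optional boolean mask.
    def __init__(self, dim=4):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states, cross_attn_mask=None):
        out = self.proj(hidden_states)
        if cross_attn_mask is not None:
            out = out * cross_attn_mask.unsqueeze(-1).float()  # placeholder for real masking
        return out


def create_custom_forward(module, **static_kwargs):
    # Keyword-only extras (like the mask) live in the closure; checkpoint() only
    # handles the positional tensor inputs it is given.
    def custom_forward(*inputs):
        return module(*inputs, **static_kwargs)

    return custom_forward


block = ToyCrossAttnBlock()
hidden = torch.randn(2, 3, 4, requires_grad=True)
context = torch.randn(2, 5, 4)  # unused by the toy block, kept for signature parity
mask = torch.ones(2, 3, dtype=torch.bool)

out = checkpoint.checkpoint(create_custom_forward(block, cross_attn_mask=mask), hidden, context)
out.sum().backward()  # gradients flow even though the mask was captured in the closure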

Diff for: src/diffusers/models/unet_2d_condition.py

+4 −1
@@ -308,6 +308,7 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         class_labels: Optional[torch.Tensor] = None,
         return_dict: bool = True,
+        cross_attn_mask: Optional[torch.Tensor] = None,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
         Args:
@@ -382,14 +383,15 @@ def forward(
                     hidden_states=sample,
                     temb=emb,
                     encoder_hidden_states=encoder_hidden_states,
+                    cross_attn_mask=cross_attn_mask,
                 )
             else:
                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
 
             down_block_res_samples += res_samples
 
         # 4. mid
-        sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+        sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states, cross_attn_mask=cross_attn_mask)
 
         # 5. up
         for i, upsample_block in enumerate(self.up_blocks):
@@ -410,6 +412,7 @@ def forward(
                     res_hidden_states_tuple=res_samples,
                     encoder_hidden_states=encoder_hidden_states,
                     upsample_size=upsample_size,
+                    cross_attn_mask=cross_attn_mask,
                 )
             else:
                 sample = upsample_block(
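
Taken together, the new keyword travels unchanged from `UNet2DConditionModel.forward` through the down/mid/up blocks into `Transformer2DModel`, `BasicTransformerBlock`, and finally `CrossAttention`. The usage sketch below shows it at the UNet level; it assumes this branch of diffusers is installed (stock releases do not accept `cross_attn_mask`), and the checkpoint name, prompts, and latent sizes are purely illustrative. The mask is taken straight from the tokenizer's attention_mask, so padded text tokens are excluded from cross-attention.

import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import UNet2DConditionModel

model_id = "CompVis/stable-diffusion-v1-4"  # illustrative checkpoint
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

prompts = ["a photo of a corgi", "a watercolor landscape"]
text_inputs = tokenizer(
    prompts,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
encoder_hidden_states = text_encoder(text_inputs.input_ids)[0]  # (batch, 77, hidden)
cross_attn_mask = text_inputs.attention_mask.bool()             # (batch, 77); True = real token

latents = torch.randn(len(prompts), unet.config.in_channels, 64, 64)
timestep = torch.tensor(999)

with torch.no_grad():
    noise_pred = unet(
        latents,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        cross_attn_mask=cross_attn_mask,  # the argument added by this commit
    ).sample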
