
Commit e3a93e9

cross-attention mask
1 parent f55dbf3 commit e3a93e9

4 files changed, +33 -6 lines
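Taken together, the changes below thread a new `cross_attn_mask` keyword from `UNet2DConditionModel.forward` through the cross-attention down/mid/up blocks and `Transformer2DModel` into `CrossAttention.get_attention_scores`, where it is applied as an additive bias over the encoder (text) tokens. A minimal usage sketch follows, under the assumption that the attention layers use the `CrossAttnProcessor` updated in this commit; the checkpoint name, shapes, and token counts are illustrative assumptions, not part of the commit:

```python
import torch
from diffusers import UNet2DConditionModel

# Hypothetical checkpoint; any SD-style UNet with cross_attention_dim=768 is used the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

sample = torch.randn(1, 4, 64, 64)               # noisy latents
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)  # text-encoder output
cross_attn_mask = torch.ones(1, 77)              # 1 = attend to this token, 0 = ignore it
cross_attn_mask[:, 20:] = 0                      # e.g. mask out padding tokens

noise_pred = unet(
    sample,
    timestep,
    encoder_hidden_states,
    cross_attn_mask=cross_attn_mask,
).sample
```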

src/diffusers/models/attention.py (+5)

@@ -181,6 +181,7 @@ def forward(
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
+        cross_attn_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
@@ -225,6 +226,7 @@ def forward(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
+                cross_attn_mask=cross_attn_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )

@@ -466,6 +468,7 @@ def forward(
        encoder_hidden_states=None,
        timestep=None,
        attention_mask=None,
+        cross_attn_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs=None,
    ):
        # 1. Self-Attention
@@ -477,6 +480,7 @@ def forward(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
+            cross_attn_mask=cross_attn_mask if self.only_cross_attention else None,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states
@@ -490,6 +494,7 @@ def forward(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
+                cross_attn_mask=cross_attn_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states
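Inside `BasicTransformerBlock`, the mask always accompanies the second attention call (the cross-attention over `encoder_hidden_states`), and reaches the first attention only when the block is configured with `only_cross_attention`, i.e. when that layer itself attends to the encoder states. A plain restatement of that routing, for illustration only (the helper name is made up):

```python
from typing import Optional, Tuple

import torch


def route_cross_attn_mask(
    cross_attn_mask: Optional[torch.Tensor], only_cross_attention: bool
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    # attn1 only sees the mask when it is effectively doing cross-attention;
    # attn2, the dedicated cross-attention layer, always sees it.
    attn1_mask = cross_attn_mask if only_cross_attention else None
    return attn1_mask, cross_attn_mask
```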

src/diffusers/models/cross_attention.py (+14, -4)

@@ -173,7 +173,7 @@ def set_subquadratic_attention(
    def set_processor(self, processor: "AttnProcessor"):
        self.processor = processor

-    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
+    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, cross_attn_mask: Optional[torch.Tensor] = None, **cross_attention_kwargs):
        # The `CrossAttention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
@@ -182,6 +182,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
+            cross_attn_mask=cross_attn_mask,
            **cross_attention_kwargs,
        )

@@ -199,11 +200,20 @@ def head_to_batch_dim(self, tensor):
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

-    def get_attention_scores(self, query, key, attention_mask=None):
+    def get_attention_scores(self, query, key, attention_mask=None, cross_attn_mask: Optional[torch.Tensor] = None):
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()
+
+        # haven't defined what to do if both are present
+        if attention_mask is not None:
+            assert cross_attn_mask is None
+        if cross_attn_mask is not None:
+            assert attention_mask is None
+            device = cross_attn_mask.device
+            cross_attn_mask = cross_attn_mask.to('cpu' if device.type == 'mps' else device).repeat_interleave(self.heads, dim=0).to(device).unsqueeze(1)
+            attention_mask = cross_attn_mask

        beta = 0 if attention_mask is None else 1
        add = torch.empty(
@@ -242,7 +252,7 @@ def prepare_attention_mask(self, attention_mask, target_length):


class CrossAttnProcessor:
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, cross_attn_mask: Optional[torch.Tensor] = None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

@@ -255,7 +265,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        attention_probs = attn.get_attention_scores(query, key, attention_mask, cross_attn_mask=cross_attn_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
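The reshaping in `get_attention_scores` repeats the mask once per head along the batch dimension and adds a singleton query axis, so it broadcasts against the score tensor of shape `(batch * heads, query_len, key_len)`; the CPU round trip in the hunk above works around `repeat_interleave` on `mps`. A standalone sketch of just that shape handling, with made-up sizes:

```python
import torch

batch, heads, query_len, key_len = 2, 8, 4096, 77
scores = torch.randn(batch * heads, query_len, key_len)

# At this point the mask is already an additive bias: 0.0 for tokens to keep,
# a very large negative value for tokens to drop.
cross_attn_mask = torch.zeros(batch, key_len)

mask = cross_attn_mask.repeat_interleave(heads, dim=0)  # (batch * heads, key_len)
mask = mask.unsqueeze(1)                                 # (batch * heads, 1, key_len)
biased_scores = scores + mask                            # broadcasts over query_len
print(biased_scores.shape)                               # torch.Size([16, 4096, 77])
```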

src/diffusers/models/unet_2d_blocks.py (+7, -2)

@@ -14,6 +14,7 @@
import numpy as np
import torch
from torch import nn
+from typing import Optional

from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
from .cross_attention import CrossAttention, CrossAttnAddedKVProcessor
@@ -483,13 +484,14 @@ def __init__(
        self.resnets = nn.ModuleList(resnets)

    def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attn_mask: Optional[torch.Tensor] = None, cross_attention_kwargs=None
    ):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
+                cross_attn_mask=cross_attn_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample
            hidden_states = resnet(hidden_states, temb)
@@ -758,7 +760,7 @@ def __init__(
        self.gradient_checkpointing = False

    def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attn_mask: Optional[torch.Tensor] = None, cross_attention_kwargs=None
    ):
        # TODO(Patrick, William) - attention mask is not used
        output_states = ()
@@ -787,6 +789,7 @@ def custom_forward(*inputs):
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attn_mask=cross_attn_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

@@ -1549,6 +1552,7 @@ def forward(
        cross_attention_kwargs=None,
        upsample_size=None,
        attention_mask=None,
+        cross_attn_mask: Optional[torch.Tensor] = None,
    ):
        # TODO(Patrick, William) - attention mask is not used
        for resnet, attn in zip(self.resnets, self.attentions):
@@ -1580,6 +1584,7 @@ def custom_forward(*inputs):
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attn_mask=cross_attn_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

src/diffusers/models/unet_2d_condition.py (+7)

@@ -354,6 +354,7 @@ def forward(
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
+        cross_attn_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
@@ -384,6 +385,9 @@ def forward(
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

+        if cross_attn_mask is not None:
+            cross_attn_mask = (1 - cross_attn_mask.to(sample.dtype)) * -torch.finfo(sample.dtype).max
+
        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
@@ -440,6 +444,7 @@ def forward(
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
+                    cross_attn_mask=cross_attn_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
@@ -453,6 +458,7 @@ def forward(
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
+            cross_attn_mask=cross_attn_mask,
            cross_attention_kwargs=cross_attention_kwargs,
        )

@@ -477,6 +483,7 @@ def forward(
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
+                    cross_attn_mask=cross_attn_mask,
                )
            else:
                sample = upsample_block(
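The conversion added at the top of `forward` turns a 1/0 keep-mask into an additive bias: 0 where attention is allowed, and the most negative representable value where it is not, so masked tokens end up with effectively zero weight after softmax (note it uses `-torch.finfo(...).max` rather than the `-10000.0` used for `attention_mask` just below it). A small sketch of the effect, with made-up scores:

```python
import torch

scores = torch.randn(2, 4, 7)  # (batch * heads, query_len, key_len)
keep = torch.tensor([[1, 1, 1, 0, 0, 0, 0],
                     [1, 1, 1, 1, 1, 0, 0]], dtype=scores.dtype)

bias = (1 - keep) * -torch.finfo(scores.dtype).max   # same transform as in the hunk above
probs = (scores + bias[:, None, :]).softmax(dim=-1)

print(probs[0, 0])  # the last four entries are exactly 0.0
```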
