import torch
import torch.nn.functional as F
-from torch import nn
+from torch import nn, Tensor
+from einops import rearrange, repeat

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
@@ -175,7 +176,7 @@ def __init__(
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)

-    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True, cross_attn_mask: Optional[torch.Tensor] = None):
        """
        Args:
            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
@@ -213,7 +214,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu

        # 2. Blocks
        for block in self.transformer_blocks:
-            hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
+            hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep, cross_attn_mask=cross_attn_mask)

        # 3. Output
        if self.is_input_continuous:
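(Annotation, not part of the patch.) The new cross_attn_mask argument is not documented in the docstring hunk above, but from the way it is consumed further down (masked_fill_(~cross_attn_mask, ...)), it is presumably a boolean tensor of shape (batch, context_len) with True marking context tokens that may be attended to. A minimal sketch of building such a mask from per-sample prompt lengths; the variable names here are hypothetical:

    import torch

    # hypothetical: two prompts padded to a shared length of 6 context tokens
    prompt_lengths = torch.tensor([4, 6])
    padded_len = 6
    # True where the context token is real, False where it is padding
    cross_attn_mask = torch.arange(padded_len)[None, :] < prompt_lengths[:, None]  # shape (2, 6), dtype bool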
@@ -472,14 +473,14 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

-    def forward(self, hidden_states, context=None, timestep=None):
+    def forward(self, hidden_states, context=None, timestep=None, cross_attn_mask: Optional[torch.Tensor] = None):
        # 1. Self-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        if self.only_cross_attention:
-            hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
+            hidden_states = self.attn1(norm_hidden_states, context, cross_attn_mask=cross_attn_mask) + hidden_states
        else:
            hidden_states = self.attn1(norm_hidden_states) + hidden_states

@@ -488,7 +489,7 @@ def forward(self, hidden_states, context=None, timestep=None):
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
-            hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
+            hidden_states = self.attn2(norm_hidden_states, context=context, cross_attn_mask=cross_attn_mask) + hidden_states

        # 3. Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -563,7 +564,7 @@ def set_attention_slice(self, slice_size):

        self._slice_size = slice_size

-    def forward(self, hidden_states, context=None, mask=None):
+    def forward(self, hidden_states, context=None, mask=None, cross_attn_mask: Optional[Tensor] = None):
        batch_size, sequence_length, _ = hidden_states.shape

        query = self.to_q(hidden_states)
@@ -577,26 +578,29 @@ def forward(self, hidden_states, context=None, mask=None):
        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

-        # TODO(PVP) - mask is currently never used. Remember to re-implement when used
+        # TODO(AKB): the `mask` param is still unimplemented; it stays reserved
+        # in case we ever need a self-attention mask
+        # (e.g. a pixel/latent-space mask to avoid self-attending to padding pixels, such as pillarboxing/letterboxing).

        # attention, what we cannot get enough of
        if self._use_memory_efficient_attention_xformers:
+            assert cross_attn_mask is None, "cross-attention masking not implemented for xformers attention"
            hidden_states = self._memory_efficient_attention_xformers(query, key, value)
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query.dtype)
        else:
            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
-                hidden_states = self._attention(query, key, value)
+                hidden_states = self._attention(query, key, value, cross_attn_mask=cross_attn_mask)
            else:
-                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, cross_attn_mask=cross_attn_mask)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states

-    def _attention(self, query, key, value):
+    def _attention(self, query, key, value, cross_attn_mask: Optional[Tensor] = None):
        if self.upcast_attention:
            query = query.float()
            key = key.float()
@@ -608,6 +612,12 @@ def _attention(self, query, key, value):
            beta=0,
            alpha=self.scale,
        )
+        if cross_attn_mask is not None:
+            cross_attn_mask = rearrange(cross_attn_mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(attention_scores.dtype).max
+            cross_attn_mask = repeat(cross_attn_mask, 'b j -> (b h) () j', h=self.heads)
+            attention_scores.masked_fill_(~cross_attn_mask, max_neg_value)
+            del cross_attn_mask
        attention_probs = attention_scores.softmax(dim=-1)

        # cast back to the original dtype
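(Annotation, not part of the patch.) A standalone sketch of what the masking block added above does to the attention scores, assuming 2 samples, 2 heads, 4 query positions and 3 context tokens; it mirrors the rearrange/repeat/masked_fill_ calls:

    import torch
    from einops import rearrange, repeat

    b, h, q_len, k_len = 2, 2, 4, 3
    attention_scores = torch.randn(b * h, q_len, k_len)  # (batch*heads, query_len, key_len), as produced by baddbmm
    mask = torch.tensor([[True, True, False],            # sample 0: last context token is padding
                         [True, True, True]])            # sample 1: nothing masked
    mask = rearrange(mask, 'b ... -> b (...)')           # flattens any trailing token dims; a no-op for a 2-D mask
    mask = repeat(mask, 'b j -> (b h) () j', h=h)        # (batch*heads, 1, key_len); broadcasts over query positions
    attention_scores.masked_fill_(~mask, -torch.finfo(attention_scores.dtype).max)
    attention_probs = attention_scores.softmax(dim=-1)   # masked context tokens end up with ~zero attention weight

The (b h) grouping in the repeat pattern matches the batch-major, head-minor layout produced by reshape_heads_to_batch_dim, so every head of a sample sees that sample's mask.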
@@ -620,11 +630,15 @@ def _attention(self, query, key, value):
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

-    def _sliced_attention(self, query, key, value, sequence_length, dim):
+    def _sliced_attention(self, query, key, value, sequence_length, dim, cross_attn_mask: Optional[Tensor] = None):
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
        )
+        if cross_attn_mask is not None:
+            cross_attn_mask = rearrange(cross_attn_mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(query.dtype).max
+            cross_attn_mask = repeat(cross_attn_mask, 'b j -> (b h) () j', h=self.heads)
        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
        for i in range(hidden_states.shape[0] // slice_size):
            start_idx = i * slice_size
@@ -644,6 +658,10 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
                beta=0,
                alpha=self.scale,
            )
+            if cross_attn_mask is not None:
+                cross_attn_mask_slice = cross_attn_mask[start_idx:end_idx]
+                attn_slice.masked_fill_(~cross_attn_mask_slice, max_neg_value)
+                del cross_attn_mask_slice
            attn_slice = attn_slice.softmax(dim=-1)

            # cast back to the original dtype
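(Annotation, not part of the patch.) A hedged caller-side sketch of how the new argument would typically be wired up. tokenizer, text_encoder, transformer, prompts, latents, and timestep are placeholders for whatever pipeline hosts this Transformer2DModel; none of them are introduced by this diff:

    tokens = tokenizer(prompts, padding="max_length", max_length=77, return_tensors="pt")
    encoder_hidden_states = text_encoder(tokens.input_ids)[0]
    cross_attn_mask = tokens.attention_mask.bool()  # (batch, 77); True = real token, False = padding

    out = transformer(
        latents,
        encoder_hidden_states=encoder_hidden_states,
        timestep=timestep,
        cross_attn_mask=cross_attn_mask,
    )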