@@ -622,31 +622,53 @@ class SlicedAttnProcessor:
     def __init__(self, slice_size):
         self.slice_size = slice_size
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
+    def __call__(
+        self,
+        attn: CrossAttention,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        encoder_attention_bias: Optional[FloatTensor] = None,
+    ):
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            if encoder_attention_bias is not None:
+                if attention_mask is not None:
+                    # it's not well-defined whether `attention_mask` should be passed to self-attention, cross-attention, neither, or both.
+                    # if two sources of bias (`attention_mask`, `encoder_attention_bias`) are provided: it's likely to be a mistake.
+                    raise ValueError(f"two attention biases have been supplied: `attention_mask` and `encoder_attention_bias`. expected a maximum of one source of bias.")
+                attention_mask = encoder_attention_bias
+                # make broadcastable over query tokens
+                # TODO: see if there's a satisfactory way to unify how the `attention_mask`/`encoder_attention_bias` code paths
+                # create this singleton dim. the way AttnProcessor2_0 does it could work.
+                # here I'm trying to avoid interfering with the original `attention_mask` code path,
+                # by limiting the unsqueeze() to just the `encoder_attention_bias` path, on the basis that
+                # `attention_mask` is already working without this change.
+                # maybe it's because UNet2DConditionModel#forward unsqueeze()s `attention_mask` earlier.
+                attention_mask = attention_mask.unsqueeze(-2)
+            if attn.cross_attention_norm:
+                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
 
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        batch_size, key_tokens, _ = encoder_hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
 
         query = attn.to_q(hidden_states)
-        dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)
 
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
-
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
-        batch_size_attention = query.shape[0]
+        batch_x_heads, query_tokens, _ = query.shape
+        inner_dim = attn.to_q.out_features
+        channels_per_head = inner_dim // attn.heads
         hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
+            (batch_x_heads, query_tokens, channels_per_head), device=query.device, dtype=query.dtype
         )
 
-        for i in range(hidden_states.shape[0] // self.slice_size):
+        for i in range(batch_x_heads // self.slice_size):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size
 
@@ -662,10 +684,10 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
+        linear_proj, dropout = attn.to_out
+
+        hidden_states = linear_proj(hidden_states)
+        hidden_states = dropout(hidden_states)
 
         return hidden_states
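
For context, here is a minimal standalone sketch of the sliced-attention pattern this processor implements: the batch×heads dimension is processed in slices so that attention-score memory peaks at `(slice_size, query_tokens, key_tokens)` rather than the full batch, and an optional additive bias (the role `encoder_attention_bias` plays after `prepare_attention_mask`/`unsqueeze(-2)`) is broadcast over query tokens. The helper name `sliced_attention` and the toy shapes below are illustrative assumptions, not diffusers API.

```python
# Illustrative sketch only (not diffusers code): q/k/v are laid out as
# (batch * heads, tokens, channels_per_head), matching head_to_batch_dim above.
import torch


def sliced_attention(q, k, v, bias=None, slice_size=1):
    # assumes slice_size evenly divides batch * heads, as the processor's loop does
    batch_x_heads, query_tokens, channels_per_head = q.shape
    scale = channels_per_head ** -0.5
    out = torch.zeros_like(q)  # mirrors the preallocated `hidden_states` buffer
    for i in range(batch_x_heads // slice_size):
        start, end = i * slice_size, (i + 1) * slice_size
        # scores for this slice only: peak memory is
        # (slice_size, query_tokens, key_tokens) instead of the whole batch
        scores = q[start:end] @ k[start:end].transpose(-1, -2) * scale
        if bias is not None:
            scores = scores + bias[start:end]  # additive bias, broadcast over query tokens
        out[start:end] = scores.softmax(dim=-1) @ v[start:end]
    return out


# toy usage: batch 1 * 2 heads, 4 query tokens, 6 key tokens, 8 channels per head
q = torch.randn(2, 4, 8)
k = torch.randn(2, 6, 8)
v = torch.randn(2, 6, 8)
bias = torch.zeros(2, 1, 6)
bias[..., 4:] = float("-inf")  # mask the last two key tokens for every query
print(sliced_attention(q, k, v, bias, slice_size=1).shape)  # torch.Size([2, 4, 8])
```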