@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, Dict, Any

 import torch
 import torch.nn.functional as F
@@ -198,7 +198,13 @@ def set_processor(self, processor: "AttnProcessor"):

         self.processor = processor

-    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
+    def forward(
+        self,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        **cross_attention_kwargs: Dict[str, Any],
+    ):
         # The `CrossAttention` class can call different attention processors / attention functions
         # here we simply pass along all tensors to the selected processor class
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty
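For context, here is a minimal sketch (not part of this commit) of how the `**cross_attention_kwargs` forwarding is meant to be used: anything the caller passes under `cross_attention_kwargs` flows through `CrossAttention.forward` into the selected processor unchanged, so a custom processor can accept extra keywords such as an `encoder_attention_bias`. The `BiasAwareProcessor` name and the printout are hypothetical.

```python
from typing import Optional

import torch


class BiasAwareProcessor:
    """Hypothetical processor: shows that extra kwargs arrive via **cross_attention_kwargs."""

    def __call__(
        self,
        attn,  # the CrossAttention module that invoked this processor
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_bias: Optional[torch.FloatTensor] = None,  # forwarded from cross_attention_kwargs
    ) -> torch.FloatTensor:
        # a real processor would compute attention here; this sketch only
        # demonstrates that the extra keyword reaches the processor.
        if encoder_attention_bias is not None:
            print("encoder_attention_bias shape:", tuple(encoder_attention_bias.shape))
        return hidden_states
```

Registering it would look roughly like `unet.set_attn_processor(BiasAwareProcessor())`, with the bias supplied via `cross_attention_kwargs={"encoder_attention_bias": bias}` at call time; the exact registration API depends on the diffusers version.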
@@ -313,8 +319,12 @@ def __call__(
                     raise ValueError(f"two attention biases have been supplied: `attention_mask` and `encoder_attention_bias`. expected a maximum of one source of bias.")
                 attention_mask = encoder_attention_bias
                 # make broadcastable over query tokens
-                # TODO: consider aligning implementations such that AttnProcessor2_0 and CrossAttnProcessor do unsqueeze
-                # in the same way/circumstances -- AttnProcessor2_0 does it for `attention_mask` **and** for `encoder_attention_bias`.
+                # TODO: see if there's a satisfactory way to unify how the `attention_mask`/`encoder_attention_bias` code paths
+                # create this singleton dim. the way AttnProcessor2_0 does it could work.
+                # here I'm trying to avoid interfering with the original `attention_mask` code path,
+                # by limiting the unsqueeze() to just the `encoder_attention_bias` path, on the basis that
+                # `attention_mask` is already working without this change.
+                # maybe it's because UNet2DConditionModel#forward unsqueeze()s `attention_mask` earlier.
                 attention_mask = attention_mask.unsqueeze(-2)
             if attn.cross_attention_norm:
                 encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
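To make the `unsqueeze(-2)` concrete, here is a shape-only sketch (illustrative sizes, not from the commit) of why a single singleton dim suffices in the pure-PyTorch processor: once the bias is `[batch, 1, key_tokens]`, ordinary broadcasting stretches it over the query dimension when it is added to the attention scores.

```python
import torch

batch, query_tokens, key_tokens = 2, 4096, 77          # illustrative sizes
scores = torch.randn(batch, query_tokens, key_tokens)  # raw attention scores

encoder_attention_bias = torch.zeros(batch, key_tokens)
encoder_attention_bias[:, 64:] = float("-inf")          # e.g. mask out padding key tokens

attention_mask = encoder_attention_bias.unsqueeze(-2)   # [batch, 1, key_tokens]
probs = (scores + attention_mask).softmax(dim=-1)       # singleton dim broadcasts over query_tokens
print(probs.shape)                                      # torch.Size([2, 4096, 77])
```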
@@ -453,18 +463,39 @@ class XFormersCrossAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
         self.attention_op = attention_op

-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
-
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        query = attn.to_q(hidden_states)
-
+    def __call__(
+        self,
+        attn: CrossAttention,
+        hidden_states: FloatTensor,
+        encoder_hidden_states: Optional[FloatTensor] = None,
+        attention_mask: Optional[FloatTensor] = None,
+        encoder_attention_bias: Optional[FloatTensor] = None,
+    ):
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        else:
+            if encoder_attention_bias is not None:
+                if attention_mask is not None:
+                    # it's not well-defined whether `attention_mask` should be passed to self-attention, cross-attention, neither* or both.
+                    # if two sources of bias (`attention_mask`, `encoder_attention_bias`) are provided: it's likely to be a mistake.
+                    raise ValueError(f"two attention biases have been supplied: `attention_mask` and `encoder_attention_bias`. expected a maximum of one source of bias.")
+                attention_mask = encoder_attention_bias
+
+                # TODO: figure out why the original `attention_mask` code path didn't attempt broadcasting over query tokens.
+                # it feels like this logic would be needed in that code path too.

+                # make broadcastable over query tokens
+                attention_mask = attention_mask.unsqueeze(-2)
+                _, query_tokens, _ = hidden_states.shape
+                # xformers doesn't broadcast for us, so we expand our singleton dimension manually
+                attention_mask = attention_mask.expand(-1, query_tokens, -1)
+            if attn.cross_attention_norm:
+                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
+        batch_size, key_tokens, _ = encoder_hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
+
+        query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

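The expansion step is what differs from the pure-PyTorch path: a fused kernel such as xformers is handed the bias tensor directly instead of adding it to a score matrix, so the singleton query dim has to be materialized up front. A small sketch with illustrative shapes (the per-head handling performed by `prepare_attention_mask` is left out):

```python
import torch

batch, query_tokens, key_tokens, channels = 2, 4096, 77, 320  # illustrative sizes
hidden_states = torch.randn(batch, query_tokens, channels)
encoder_attention_bias = torch.zeros(batch, key_tokens)

attention_mask = encoder_attention_bias.unsqueeze(-2)          # [batch, 1, key_tokens]
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)   # [batch, query_tokens, key_tokens]

print(attention_mask.shape)     # torch.Size([2, 4096, 77])
print(attention_mask.stride())  # (77, 0, 1) -- expand() returns a view, no extra memory is allocated
```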
@@ -478,10 +509,10 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         hidden_states = hidden_states.to(query.dtype)
         hidden_states = attn.batch_to_head_dim(hidden_states)

-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
+        linear_proj, dropout = attn.to_out
+
+        hidden_states = linear_proj(hidden_states)
+        hidden_states = dropout(hidden_states)

         return hidden_states

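This last hunk is purely cosmetic: in diffusers, `attn.to_out` is an `nn.ModuleList` holding a `Linear` projection followed by a `Dropout`, so unpacking it gives the two stages readable names instead of indexing `to_out[0]`/`to_out[1]`. A standalone sketch (dimensions are illustrative):

```python
import torch
from torch import nn

# stand-in for CrossAttention.to_out: a Linear projection followed by Dropout
to_out = nn.ModuleList([nn.Linear(320, 320), nn.Dropout(p=0.0)])

hidden_states = torch.randn(2, 4096, 320)
linear_proj, dropout = to_out                      # same unpacking as in the diff
hidden_states = dropout(linear_proj(hidden_states))
print(hidden_states.shape)                         # torch.Size([2, 4096, 320])
```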