@@ -1707,17 +1707,28 @@ def forward(
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        # parameter exists only for interface-compatibility with other blocks. prefer attention_mask
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+        if attention_mask is None:
+            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+            mask = None if encoder_hidden_states is None else encoder_attention_mask
+        else:
+            # when attention_mask is defined: we don't even check for encoder_attention_mask.
+            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+            #       then we can simplify this whole if/else block to:
+            #       mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+            mask = attention_mask
+
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             # attn
             hidden_states = attn(
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=mask,
                 **cross_attention_kwargs,
             )
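For reference, the mask-selection rule this hunk introduces can be read as a small standalone helper. The sketch below is a minimal illustration only; the helper name resolve_attention_mask is hypothetical (the hunk inlines this logic directly in forward), but the branching mirrors the added code: an explicit attention_mask always wins, because UnCLIP passes its cross-attn mask through that parameter; otherwise cross-attention falls back to encoder_attention_mask and self-attention gets no mask.

from typing import Optional

import torch


def resolve_attention_mask(
    attention_mask: Optional[torch.Tensor],
    encoder_hidden_states: Optional[torch.Tensor],
    encoder_attention_mask: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    # Hypothetical helper mirroring the if/else added in this hunk.
    if attention_mask is None:
        # Cross-attn (encoder_hidden_states given): use the cross-attn mask; self-attn: no mask.
        return None if encoder_hidden_states is None else encoder_attention_mask
    # An explicit attention_mask takes precedence, to stay compatible with UnCLIP,
    # which supplies its cross-attn mask via this parameter.
    return attention_mask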