Commit 8c323c5

fix copies
1 parent 83762b2 commit 8c323c5

File tree

1 file changed: +13 -2 lines changed

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py (+13 -2)
@@ -1707,17 +1707,28 @@ def forward(
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        # parameter exists only for interface-compatibility with other blocks. prefer attention_mask
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+        if attention_mask is None:
+            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+            mask = None if encoder_hidden_states is None else encoder_attention_mask
+        else:
+            # when attention_mask is defined: we don't even check for encoder_attention_mask.
+            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+            #       then we can simplify this whole if/else block to:
+            #       mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+            mask = attention_mask
+
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             # attn
             hidden_states = attn(
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=mask,
                 **cross_attention_kwargs,
             )

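For context, a minimal standalone sketch of the mask-selection fallback this diff adds. This is an assumption-laden illustration, not the diffusers implementation: the helper name `select_attention_mask` and the example tensor shapes are hypothetical. The rule it mirrors is the one spelled out in the diff's comments: an explicit `attention_mask` always wins (UnCLIP compatibility); otherwise the cross-attention mask is used only when `encoder_hidden_states` is present.

# Hypothetical helper illustrating the mask-selection logic from the diff above.
from typing import Optional

import torch


def select_attention_mask(
    attention_mask: Optional[torch.Tensor],
    encoder_hidden_states: Optional[torch.Tensor],
    encoder_attention_mask: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    if attention_mask is None:
        # Cross-attention case: fall back to the encoder (cross-attn) mask,
        # but only if encoder_hidden_states is actually provided.
        return None if encoder_hidden_states is None else encoder_attention_mask
    # An explicit attention_mask always takes precedence, mirroring the
    # UnCLIP-compatibility behaviour described in the diff's comments.
    return attention_mask


# Self-attention with no masks at all -> no mask is applied.
assert select_attention_mask(None, None, None) is None

# Cross-attention with only an encoder mask -> the encoder mask is used.
enc_states = torch.randn(1, 77, 768)   # example shapes only
enc_mask = torch.ones(1, 77, dtype=torch.bool)
assert select_attention_mask(None, enc_states, enc_mask) is enc_mask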