Commit e35b7df

fix-copies
1 parent e47e1c0 commit e35b7df

1 file changed: 90 additions, 55 deletions

Diff for: src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

@@ -721,13 +721,18 @@ def forward(
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
         Args:
             sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
             timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
             encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            encoder_attention_mask (`torch.Tensor`):
+                (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
+                discard. Mask will be converted into a bias, which adds large negative values to attention scores
+                corresponding to "discard" tokens.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
             cross_attention_kwargs (`dict`, *optional*):
@@ -754,11 +759,27 @@ def forward(
             logger.info("Forward upsample size to force interpolation output size.")
             forward_upsample_size = True
 
-        # prepare attention_mask
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
         if attention_mask is not None:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
             attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
             attention_mask = attention_mask.unsqueeze(1)
 
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
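Note on the hunk above: the mask-to-bias conversion is easy to check in isolation. A standalone sketch (shapes and dtype chosen for illustration, not taken from the diff) showing how a 1/0 keep-discard mask becomes an additive bias that broadcasts over attention scores and suppresses the discarded keys after softmax:

import torch

# illustrative mask: batch=2, key_tokens=4; 1 = keep, 0 = discard
encoder_attention_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
sample_dtype = torch.float32  # stands in for sample.dtype in the diff

# same transform as the diff: keep -> +0.0, discard -> -10000.0
bias = (1 - encoder_attention_mask.to(sample_dtype)) * -10000.0
bias = bias.unsqueeze(1)  # [batch, 1, key_tokens]; broadcasts over query_tokens (and heads)

# dummy [batch, query_tokens, key_tokens] attention scores
scores = torch.zeros(2, 3, 4, dtype=sample_dtype)
probs = (scores + bias).softmax(dim=-1)
print(probs[0, 0])  # masked-out key tokens end up with ~0 attention probability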
@@ -830,6 +851,7 @@ def forward(
                     encoder_hidden_states=encoder_hidden_states,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
@@ -855,6 +877,7 @@ def forward(
             encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,
             cross_attention_kwargs=cross_attention_kwargs,
+            encoder_attention_mask=encoder_attention_mask,
         )
 
         if mid_block_additional_residual is not None:
@@ -881,6 +904,7 @@ def forward(
                     cross_attention_kwargs=cross_attention_kwargs,
                     upsample_size=upsample_size,
                     attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 sample = upsample_block(
@@ -1188,9 +1212,14 @@ def __init__(
         self.gradient_checkpointing = False
 
     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
-        # TODO(Patrick, William) - attention mask is not used
         output_states = ()
 
         for resnet, attn in zip(self.resnets, self.attentions):
@@ -1205,33 +1234,32 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                        use_reentrant=False,
-                    )[0]
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                    )[0]
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    None,  # timestep
+                    None,  # class_labels
+                    cross_attention_kwargs,
+                    attention_mask,
+                    encoder_attention_mask,
+                    **ckpt_kwargs,
+                )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
 
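Note on the checkpointing change above: the two torch-version branches collapse into one call site by building the optional use_reentrant kwarg up front. A standalone sketch of the same pattern (it checks torch.__version__ via packaging instead of diffusers' is_torch_version helper, and the small linear layer is only for illustration):

from typing import Any, Dict

import torch
import torch.utils.checkpoint
from packaging import version

# stand-in for is_torch_version(">=", "1.11.0")
torch_ge_1_11 = version.parse(torch.__version__.split("+")[0]) >= version.parse("1.11.0")


def run_checkpointed(module: torch.nn.Module, *inputs: torch.Tensor) -> torch.Tensor:
    # pass use_reentrant=False only where torch supports it, instead of duplicating the call
    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if torch_ge_1_11 else {}
    return torch.utils.checkpoint.checkpoint(module, *inputs, **ckpt_kwargs)


# usage: gradients still flow through the checkpointed layer, with activations recomputed in backward
layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
run_checkpointed(layer, x).sum().backward()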
@@ -1414,15 +1442,15 @@ def __init__(
 
     def forward(
         self,
-        hidden_states,
-        res_hidden_states_tuple,
-        temb=None,
-        encoder_hidden_states=None,
-        cross_attention_kwargs=None,
-        upsample_size=None,
-        attention_mask=None,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ):
-        # TODO(Patrick, William) - attention mask is not used
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -1440,33 +1468,32 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                        use_reentrant=False,
-                    )[0]
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                    )[0]
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    None,  # timestep
+                    None,  # class_labels
+                    cross_attention_kwargs,
+                    attention_mask,
+                    encoder_attention_mask,
+                    **ckpt_kwargs,
+                )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
 
@@ -1564,14 +1591,22 @@ def __init__(
         self.resnets = nn.ModuleList(resnets)
 
     def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
-    ):
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             hidden_states = attn(
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
                 return_dict=False,
             )[0]
             hidden_states = resnet(hidden_states, temb)
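For context, a rough usage sketch of the new encoder_attention_mask argument. It uses UNet2DConditionModel, whose forward signature this flat UNet mirrors via fix-copies; the tiny config and shapes are illustrative assumptions, and it presumes a diffusers version that already contains this change:

import torch
from diffusers import UNet2DConditionModel

# small illustrative config (assumption, sized for a quick smoke test rather than real use)
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
)

sample = torch.randn(1, 4, 32, 32)              # noisy latents: (batch, channel, height, width)
timestep = torch.tensor(10)                     # diffusion timestep
encoder_hidden_states = torch.randn(1, 77, 32)  # (batch, sequence_length, cross_attention_dim)

# True = keep, False = discard; here we pretend the prompt is only 40 tokens long
encoder_attention_mask = torch.ones(1, 77, dtype=torch.bool)
encoder_attention_mask[:, 40:] = False

out = unet(
    sample,
    timestep,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])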
