@@ -721,13 +721,18 @@ def forward(
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            encoder_attention_mask (`torch.Tensor`):
+                (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
+                discard. Mask will be converted into a bias, which adds large negative values to attention scores
+                corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
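A minimal usage sketch of the new `encoder_attention_mask` argument (not part of the diff; the checkpoint id, shapes, and number of text tokens are illustrative assumptions):

import torch
from diffusers import UNet2DConditionModel

# any SD-1.x-style UNet with cross_attention_dim=768 works here; the repo id is only an example
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

sample = torch.randn(1, 4, 64, 64)                # (batch, channel, height, width) noisy latents
timestep = 10
encoder_hidden_states = torch.randn(1, 77, 768)   # (batch, sequence_length, feature_dim) text embeddings
encoder_attention_mask = torch.ones(1, 77, dtype=torch.bool)
encoder_attention_mask[:, 60:] = False            # discard trailing padding tokens

out = unet(
    sample,
    timestep,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,  # converted internally to a -10000.0 bias on "discard" keys
).sample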
@@ -754,11 +759,27 @@ def forward(
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

-        # prepare attention_mask
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0
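The mask-to-bias conversion above is easy to sanity-check in isolation. A minimal sketch (not part of the diff; the shapes are arbitrary):

import torch

mask = torch.tensor([[1, 1, 0]])                 # [batch=1, key_tokens=3]; 1 = keep, 0 = discard
bias = (1 - mask.to(torch.float32)) * -10000.0   # keep -> 0.0, discard -> -10000.0
bias = bias.unsqueeze(1)                         # [1, 1, 3]: singleton query_tokens dimension

scores = torch.zeros(1, 8, 4, 3)                 # [batch, heads, query_tokens, key_tokens] attention scores
probs = (scores + bias).softmax(dim=-1)          # bias broadcasts; key 2 gets ~0 attention weight for every query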
@@ -830,6 +851,7 @@ def forward(
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
@@ -855,6 +877,7 @@ def forward(
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
+            encoder_attention_mask=encoder_attention_mask,
        )

        if mid_block_additional_residual is not None:
@@ -881,6 +904,7 @@ def forward(
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
@@ -1188,9 +1212,14 @@ def __init__(
        self.gradient_checkpointing = False

    def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
-        # TODO(Patrick, William) - attention mask is not used
        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
@@ -1205,33 +1234,32 @@ def custom_forward(*inputs):

                    return custom_forward

-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                        use_reentrant=False,
-                    )[0]
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                    )[0]
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    None,  # timestep
+                    None,  # class_labels
+                    cross_attention_kwargs,
+                    attention_mask,
+                    encoder_attention_mask,
+                    **ckpt_kwargs,
+                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
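The hunk above folds the duplicated torch-version branches into a single `ckpt_kwargs` dict. A standalone sketch of the same pattern (not diffusers code; the linear layer and tensors are placeholders, and `is_torch_version` is assumed importable from `diffusers.utils`):

from typing import Any, Dict

import torch
import torch.utils.checkpoint

from diffusers.utils import is_torch_version

layer = torch.nn.Linear(16, 16)        # placeholder module standing in for a resnet/attention block
x = torch.randn(2, 16, requires_grad=True)

# prefer non-reentrant checkpointing where available (torch >= 1.11), otherwise omit the kwarg
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
y = torch.utils.checkpoint.checkpoint(layer, x, **ckpt_kwargs)
y.sum().backward()                     # activations inside `layer` are recomputed during the backward pass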
@@ -1414,15 +1442,15 @@ def __init__(
    def forward(
        self,
-        hidden_states,
-        res_hidden_states_tuple,
-        temb=None,
-        encoder_hidden_states=None,
-        cross_attention_kwargs=None,
-        upsample_size=None,
-        attention_mask=None,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
-        # TODO(Patrick, William) - attention mask is not used
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
@@ -1440,33 +1468,32 @@ def custom_forward(*inputs):

                    return custom_forward

-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                        use_reentrant=False,
-                    )[0]
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False),
-                        hidden_states,
-                        encoder_hidden_states,
-                        cross_attention_kwargs,
-                    )[0]
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    None,  # timestep
+                    None,  # class_labels
+                    cross_attention_kwargs,
+                    attention_mask,
+                    encoder_attention_mask,
+                    **ckpt_kwargs,
+                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
@@ -1564,14 +1591,22 @@ def __init__(
        self.resnets = nn.ModuleList(resnets)

    def forward(
-        self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
-    ):
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
            hidden_states = resnet(hidden_states, temb)