@@ -173,7 +173,8 @@ def set_use_memory_efficient_attention_xformers(
             LORA_ATTENTION_PROCESSORS,
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
-            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
+            self.processor,
+            (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
         )
         is_added_kv_processor = hasattr(self, "processor") and isinstance(
             self.processor,
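A minimal sketch (not part of the diff) of what the broadened check buys: a module whose processor is the new `CustomDiffusionAttnProcessor2_0` is now also treated as a Custom Diffusion processor when toggling xformers. The tuple name below is local to the sketch, not a diffusers constant, and it assumes a diffusers build that already contains this change.

```python
# Mirrors the expanded isinstance() check above.
from diffusers.models.attention_processor import (
    CustomDiffusionAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
    CustomDiffusionXFormersAttnProcessor,
)

CUSTOM_DIFFUSION_PROCESSORS = (  # local helper tuple, not a library constant
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
)

processor = CustomDiffusionAttnProcessor2_0(hidden_size=320, cross_attention_dim=768)
assert isinstance(processor, CUSTOM_DIFFUSION_PROCESSORS)
```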
@@ -261,7 +262,12 @@ def set_use_memory_efficient_attention_xformers(
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
             elif is_custom_diffusion:
-                processor = CustomDiffusionAttnProcessor(
+                attn_processor_class = (
+                    CustomDiffusionAttnProcessor2_0
+                    if hasattr(F, "scaled_dot_product_attention")
+                    else CustomDiffusionAttnProcessor
+                )
+                processor = attn_processor_class(
                     train_kv=self.processor.train_kv,
                     train_q_out=self.processor.train_q_out,
                     hidden_size=self.processor.hidden_size,
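The fallback rule above can be exercised on its own. A small sketch, assuming a diffusers build that already exposes `CustomDiffusionAttnProcessor2_0`:

```python
# Prefer the SDPA-based processor on PyTorch >= 2.0, otherwise keep the vanilla one.
import torch.nn.functional as F

from diffusers.models.attention_processor import (
    CustomDiffusionAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
)

attn_processor_class = (
    CustomDiffusionAttnProcessor2_0
    if hasattr(F, "scaled_dot_product_attention")
    else CustomDiffusionAttnProcessor
)
print(attn_processor_class.__name__)  # "CustomDiffusionAttnProcessor2_0" on torch >= 2.0
```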
@@ -1156,6 +1162,111 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         return hidden_states


+class CustomDiffusionAttnProcessor2_0(nn.Module):
+    r"""
+    Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0's memory-efficient scaled
+    dot-product attention.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include the bias parameter in `train_q_out`.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+    """
+
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=True,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        inner_dim = hidden_states.shape[-1]
+
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class SlicedAttnProcessor:
     r"""
     Processor for implementing sliced attention.
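To see the new processor end to end, here is a minimal sketch that drops it into a single `Attention` module and runs a cross-attention forward pass. It assumes PyTorch >= 2.0 and that `Attention` is importable from the same module as the processors; the layer sizes (320/768, 8 heads of 40) are hypothetical values mimicking one Stable Diffusion cross-attention layer, not anything prescribed by this PR.

```python
import torch

from diffusers.models.attention_processor import Attention, CustomDiffusionAttnProcessor2_0

# Hypothetical sizes for one cross-attention layer; inner_dim (heads * dim_head)
# must equal hidden_size for the head reshape inside the processor to line up.
attn = Attention(
    query_dim=320,
    cross_attention_dim=768,
    heads=8,
    dim_head=40,
    processor=CustomDiffusionAttnProcessor2_0(
        train_kv=True,
        train_q_out=False,
        hidden_size=320,
        cross_attention_dim=768,
    ),
)

hidden_states = torch.randn(2, 64, 320)          # (batch, latent tokens, channels)
encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, text tokens, channels)

with torch.no_grad():
    out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)

print(out.shape)  # torch.Size([2, 64, 320])
```

The query and output projections come from the base `Attention` module here (`train_q_out=False`), while the key/value projections are the newly trained Custom Diffusion linears, matching the defaults used when only cross-attention KV weights are fine-tuned.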
@@ -1639,6 +1750,7 @@ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
     XFormersAttnAddedKVProcessor,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
+    CustomDiffusionAttnProcessor2_0,
     # depraceted
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,