
Commit 1497650

fuse attention mask (#2111)
* fuse attention mask
* lint
* use 0 beta when no attention mask re: @Birch-san
1 parent 96af5bf commit 1497650

1 file changed: +13 −6


Diff for: src/diffusers/models/cross_attention.py (+13 −6)
```diff
@@ -185,17 +185,23 @@ def get_attention_scores(self, query, key, attention_mask=None):
             query = query.float()
             key = key.float()
 
+        if attention_mask is None:
+            baddbmm_input = torch.empty(
+                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+            )
+            beta = 0
+        else:
+            baddbmm_input = attention_mask
+            beta = 1
+
         attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            baddbmm_input,
             query,
             key.transpose(-1, -2),
-            beta=0,
+            beta=beta,
             alpha=self.scale,
         )
 
-        if attention_mask is not None:
-            attention_scores = attention_scores + attention_mask
-
         if self.upcast_softmax:
             attention_scores = attention_scores.float()
 
```
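For context: `torch.baddbmm(input, batch1, batch2, beta=..., alpha=...)` computes `beta * input + alpha * (batch1 @ batch2)`, so passing the attention mask as `input` with `beta=1` folds the old separate mask addition into the same kernel as the scaled QKᵀ matmul. With `beta=0`, PyTorch documents that `input` is ignored (NaN/inf in it are not propagated), which is why an uninitialized `torch.empty` buffer is safe in the no-mask branch. A minimal standalone sketch of the equivalence; the shapes below are illustrative, not from the commit:

```python
import torch

B, T_q, T_k, D = 2, 4, 6, 8          # batch*heads, query len, key len, head dim
scale = D ** -0.5
query = torch.randn(B, T_q, D)
key = torch.randn(B, T_k, D)
attention_mask = torch.randn(B, T_q, T_k)  # additive mask (0 / large negative in practice)

# Old path: matmul into a throwaway buffer (ignored because beta=0), then a separate add.
unfused = torch.baddbmm(
    torch.empty(B, T_q, T_k),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
) + attention_mask

# New path: the mask rides along as baddbmm's `input` with beta=1.
fused = torch.baddbmm(attention_mask, query, key.transpose(-1, -2), beta=1, alpha=scale)

assert torch.allclose(unfused, fused, atol=1e-6)
```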
```diff
@@ -228,11 +234,12 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
         query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
 
         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
```
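The second hunk is a pure reorder: the query reshape moves down so all three `head_to_batch_dim` calls sit together after the q/k/v projections. As a rough paraphrase (not the commit's code) of what that reshape does, it folds the head dimension into the batch dimension, which is also why `prepare_attention_mask` must produce a mask whose leading dimension is batch × heads for the fused `baddbmm` above:

```python
import torch

def head_to_batch_dim(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch, seq_len, heads * head_dim) -> (batch * heads, seq_len, head_dim),
    # so baddbmm performs one matmul per (batch, head) pair.
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads)
```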
