
Commit f55dbf3

fuse addition of attention mask

1 parent 9dc6822

File tree

1 file changed: +11 -5 lines

src/diffusers/models/cross_attention.py

```diff
@@ -205,17 +205,23 @@ def get_attention_scores(self, query, key, attention_mask=None):
             query = query.float()
             key = key.float()
 
+        beta = 0 if attention_mask is None else 1
+        add = torch.empty(
+            query.shape[0],
+            query.shape[1],
+            key.shape[1],
+            dtype=query.dtype,
+            device=query.device
+        ) if attention_mask is None else attention_mask
+
         attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            add,
             query,
             key.transpose(-1, -2),
-            beta=0,
+            beta=beta,
             alpha=self.scale,
         )
 
-        if attention_mask is not None:
-            attention_scores = attention_scores + attention_mask
-
         if self.upcast_softmax:
             attention_scores = attention_scores.float()
 
```
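The change replaces the two-step "matmul, then add the mask" with a single fused `torch.baddbmm` call: `baddbmm` computes `beta * input + alpha * (batch1 @ batch2)`, so passing the mask as `input` with `beta=1` folds the addition into the same kernel. A minimal sanity check of that equivalence (not part of the commit; shapes and the stand-in `scale` value are illustrative):

```python
import torch

# Illustrative shapes; `scale` stands in for self.scale in the real code.
batch, q_len, k_len, dim = 2, 4, 6, 8
scale = dim ** -0.5

query = torch.randn(batch, q_len, dim)
key = torch.randn(batch, k_len, dim)
attention_mask = torch.randn(batch, q_len, k_len)

# Old path: beta=0 makes baddbmm ignore the (uninitialized) input tensor,
# so the mask has to be added in a separate op afterwards.
old = torch.baddbmm(
    torch.empty(batch, q_len, k_len, dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
) + attention_mask

# New path: the mask is baddbmm's `input` and beta=1 fuses the addition
# into the same kernel (beta * input + alpha * query @ key^T).
new = torch.baddbmm(attention_mask, query, key.transpose(-1, -2), beta=1, alpha=scale)

torch.testing.assert_close(old, new)
```

One detail the `None` branch relies on: PyTorch documents that when `beta=0` the `input` to `baddbmm` is ignored (even NaN/inf values are not propagated), so the uninitialized `torch.empty` tensor never contributes to the scores.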