
Commit 3c92600

committed Dec 30, 2022
pre-transpose key, rather than transposing it then undoing the transpose during the matmul
1 parent 0eafb95 · commit 3c92600

File tree

1 file changed: +5 −4 lines changed

src/diffusers/models/cross_attention.py (+5 −4)
@@ -303,19 +303,20 @@ def __call__(
         value = attn.to_v(encoder_hidden_states)
 
         query = query.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1)
-        key = key.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1)
+        key_t = key.transpose(1,2).unflatten(1, (attn.heads, -1)).flatten(end_dim=1)
+        del key
         value = value.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1)
 
         dtype = query.dtype
         # TODO: do we still need to do *everything* in float32, given how we delay the division?
         # TODO: do we need to support upcast_softmax too? SD 2.1 seems to work without it
         if attn.upcast_attention:
             query = query.float()
-            key = key.float()
+            key_t = key_t.float()
 
         bytes_per_token = torch.finfo(query.dtype).bits//8
         batch_x_heads, q_tokens, _ = query.shape
-        _, k_tokens, _ = key.shape
+        _, _, k_tokens = key_t.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
         query_chunk_size = self.query_chunk_size
@@ -329,7 +330,7 @@ def __call__(
 
         hidden_states = efficient_dot_product_attention(
             query,
-            key,
+            key_t,
             value,
             query_chunk_size=query_chunk_size,
             kv_chunk_size=kv_chunk_size,
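
Why this helps: both reshape orders produce the same values, but the old path leaves key in (batch*heads, k_tokens, head_dim) layout, which the attention matmul then has to transpose back, while the new key_t is already (batch*heads, head_dim, k_tokens) and can be consumed by the matmul as-is. Below is a minimal sketch of that equivalence with made-up shapes that are not part of this diff; the plain score matmul stands in for whatever efficient_dot_product_attention does internally and is not its actual body.

import torch

# Hypothetical shapes for illustration only: batch=2, 77 key tokens,
# 10 query tokens, 8 heads of 64 channels each.
batch, k_tokens, q_tokens, heads, head_dim = 2, 77, 10, 8, 64
key = torch.randn(batch, k_tokens, heads * head_dim)
query = torch.randn(batch * heads, q_tokens, head_dim)

# Old layout: (batch*heads, k_tokens, head_dim); the score matmul must
# transpose it again to produce the (q_tokens, k_tokens) attention scores.
key_old = key.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)
scores_old = query @ key_old.transpose(-1, -2)

# New layout: transpose channels-first up front, so the flattened tensor is
# already (batch*heads, head_dim, k_tokens) and feeds the matmul directly.
key_t = key.transpose(1, 2).unflatten(1, (heads, -1)).flatten(end_dim=1)
scores_new = query @ key_t

assert key_t.shape == (batch * heads, head_dim, k_tokens)
assert torch.allclose(scores_old, scores_new, atol=1e-5)

The del key and the _, _, k_tokens = key_t.shape unpacking in the diff follow directly from this layout change: the original key tensor is no longer needed once key_t exists, and the token count now lives in the last dimension.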
