
Commit 9dc6822

pre-transpose key, rather than transposing it then undoing the transpose during the matmul
1 parent 0eafb95 commit 9dc6822
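
This commit passes the key into the attention kernels already transposed, so the kernels can feed it straight to baddbmm instead of calling key.transpose(1,2) inside every chunk. A minimal sketch of why the two layouts give identical attention scores (toy shapes, not part of the commit):

import torch

# The real code uses [batch * num_heads, tokens, channels_per_head] activations.
batch_x_heads, q_tokens, k_tokens, channels = 2, 4, 6, 8
query = torch.randn(batch_x_heads, q_tokens, channels)
key = torch.randn(batch_x_heads, k_tokens, channels)

key_t = key.transpose(1, 2)  # transposed once, up front: [batch*heads, channels, k_tokens]
scale = channels ** -0.5

# Before: transpose the key at matmul time.
scores_before = torch.baddbmm(
    torch.empty(1, 1, 1, dtype=query.dtype),
    query, key.transpose(1, 2), alpha=scale, beta=0,
)
# After: the key already has the layout baddbmm expects for its second operand.
scores_after = torch.baddbmm(
    torch.empty(1, 1, 1, dtype=query.dtype),
    query, key_t, alpha=scale, beta=0,
)
assert torch.allclose(scores_before, scores_after)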

File tree: 2 files changed (+23 -22 lines)


src/diffusers/models/cross_attention.py

+5 -4
@@ -303,19 +303,20 @@ def __call__(
         value = attn.to_v(encoder_hidden_states)
 
         query = query.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1)
-        key = key.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1)
+        key_t = key.transpose(1,2).unflatten(1, (attn.heads, -1)).flatten(end_dim=1)
+        del key
         value = value.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1)
 
         dtype = query.dtype
         # TODO: do we still need to do *everything* in float32, given how we delay the division?
         # TODO: do we need to support upcast_softmax too? SD 2.1 seems to work without it
         if attn.upcast_attention:
             query = query.float()
-            key = key.float()
+            key_t = key_t.float()
 
         bytes_per_token = torch.finfo(query.dtype).bits//8
         batch_x_heads, q_tokens, _ = query.shape
-        _, k_tokens, _ = key.shape
+        _, _, k_tokens = key_t.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
         query_chunk_size = self.query_chunk_size
@@ -329,7 +330,7 @@ def __call__(
 
         hidden_states = efficient_dot_product_attention(
             query,
-            key,
+            key_t,
             value,
             query_chunk_size=query_chunk_size,
             kv_chunk_size=kv_chunk_size,
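
A quick sanity check (toy shapes, not from the commit) that the new key_t reshape above yields the same per-head data as the old key reshape, only with the token and channel dims swapped; the `del key` then drops the untransposed copy so it is not kept alive alongside key_t:

import torch

batch, tokens, heads, head_dim = 2, 5, 4, 8
key = torch.randn(batch, tokens, heads * head_dim)

# Old layout: [batch * heads, tokens, head_dim]
key_old = key.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)
# New layout: [batch * heads, head_dim, tokens]
key_t = key.transpose(1, 2).unflatten(1, (heads, -1)).flatten(end_dim=1)

assert key_t.shape == (batch * heads, head_dim, tokens)
assert torch.equal(key_t, key_old.transpose(1, 2))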

src/diffusers/models/sub_quadratic_attention.py

+18 -18
@@ -26,28 +26,28 @@ class SummarizeChunk(Protocol):
     @staticmethod
     def __call__(
         query: Tensor,
-        key: Tensor,
+        key_t: Tensor,
         value: Tensor,
     ) -> AttnChunk: ...
 
 class ComputeQueryChunkAttn(Protocol):
     @staticmethod
     def __call__(
         query: Tensor,
-        key: Tensor,
+        key_t: Tensor,
         value: Tensor,
     ) -> Tensor: ...
 
 def _summarize_chunk(
     query: Tensor,
-    key: Tensor,
+    key_t: Tensor,
     value: Tensor,
     scale: float,
 ) -> AttnChunk:
     attn_weights = torch.baddbmm(
         torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
-        key.transpose(1,2),
+        key_t,
         alpha=scale,
         beta=0,
     )
@@ -60,19 +60,19 @@ def _summarize_chunk(
 
 def _query_chunk_attention(
     query: Tensor,
-    key: Tensor,
+    key_t: Tensor,
     value: Tensor,
     summarize_chunk: SummarizeChunk,
     kv_chunk_size: int,
 ) -> Tensor:
-    batch_x_heads, k_tokens, k_channels_per_head = key.shape
+    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
     _, _, v_channels_per_head = value.shape
 
     def chunk_scanner(chunk_idx: int) -> AttnChunk:
         key_chunk = dynamic_slice(
-            key,
-            (0, chunk_idx, 0),
-            (batch_x_heads, kv_chunk_size, k_channels_per_head)
+            key_t,
+            (0, 0, chunk_idx),
+            (batch_x_heads, k_channels_per_head, kv_chunk_size)
         )
         value_chunk = dynamic_slice(
             value,
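
With the key stored as [batch*heads, channels_per_head, k_tokens], chunk_scanner now slices tokens along the last dim instead of dim 1. A sketch of the equivalence using torch.narrow (dynamic_slice in this repo is assumed to behave like a start/size slice; toy shapes only):

import torch

batch_x_heads, k_tokens, channels = 2, 16, 8
key = torch.randn(batch_x_heads, k_tokens, channels)
key_t = key.transpose(1, 2)  # [batch*heads, channels, k_tokens]

chunk_idx, kv_chunk_size = 4, 4
# Before: take kv_chunk_size tokens along dim 1 of key.
old_chunk = key.narrow(1, chunk_idx, kv_chunk_size)    # [batch*heads, kv_chunk_size, channels]
# After: take the same tokens along dim 2 of key_t.
new_chunk = key_t.narrow(2, chunk_idx, kv_chunk_size)  # [batch*heads, channels, kv_chunk_size]

assert torch.equal(new_chunk, old_chunk.transpose(1, 2))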
@@ -99,14 +99,14 @@ def chunk_scanner(chunk_idx: int) -> AttnChunk:
 # TODO: refactor CrossAttention#get_attention_scores to share code with this
 def _get_attention_scores_no_kv_chunking(
     query: Tensor,
-    key: Tensor,
+    key_t: Tensor,
     value: Tensor,
     scale: float,
 ) -> Tensor:
     attn_scores = torch.baddbmm(
         torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
-        key.transpose(1,2),
+        key_t,
         alpha=scale,
         beta=0,
     )
@@ -121,21 +121,21 @@ class ScannedChunk(NamedTuple):
 
 def efficient_dot_product_attention(
     query: Tensor,
-    key: Tensor,
+    key_t: Tensor,
     value: Tensor,
     query_chunk_size=1024,
     kv_chunk_size: Optional[int] = None,
     kv_chunk_size_min: Optional[int] = None,
     use_checkpoint=True,
 ):
-    """Computes efficient dot-product attention given query, key, and value.
+    """Computes efficient dot-product attention given query, transposed key, and value.
     This is efficient version of attention presented in
     https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
     Args:
       query: queries for calculating attention with shape of
         `[batch * num_heads, tokens, channels_per_head]`.
-      key: keys for calculating attention with shape of
-        `[batch * num_heads, tokens, channels_per_head]`.
+      key_t: keys for calculating attention with shape of
+        `[batch * num_heads, channels_per_head, tokens]`.
       value: values to be used in attention with shape of
         `[batch * num_heads, tokens, channels_per_head]`.
       query_chunk_size: int: query chunks size
@@ -146,7 +146,7 @@ def efficient_dot_product_attention(
       Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
     """
     batch_x_heads, q_tokens, q_channels_per_head = query.shape
-    _, k_tokens, _ = key.shape
+    _, _, k_tokens = key_t.shape
     scale = q_channels_per_head ** -0.5
 
     kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
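
For example, with k_tokens = 4096 and kv_chunk_size left as None, the chunk size defaults to int(sqrt(4096)) = 64, which is what gives the O(sqrt(n)) memory behaviour cited in the docstring.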
@@ -178,7 +178,7 @@ def get_query_chunk(chunk_idx: int) -> Tensor:
         # fast-path for when there's just 1 query chunk
         return compute_query_chunk_attn(
             query=query,
-            key=key,
+            key_t=key_t,
            value=value,
         )
 
@@ -187,7 +187,7 @@ def get_query_chunk(chunk_idx: int) -> Tensor:
     res = torch.cat([
         compute_query_chunk_attn(
             query=get_query_chunk(i * query_chunk_size),
-            key=key,
+            key_t=key_t,
             value=value,
         ) for i in range(math.ceil(q_tokens / query_chunk_size))
     ], dim=1)
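
Putting the two files together, a hedged usage sketch of the updated entry point (import path assumed from this repo's file layout; shapes are illustrative):

import torch
from diffusers.models.sub_quadratic_attention import efficient_dot_product_attention

batch_x_heads, q_tokens, k_tokens, channels = 16, 4096, 4096, 40
query = torch.randn(batch_x_heads, q_tokens, channels)
key = torch.randn(batch_x_heads, k_tokens, channels)
value = torch.randn(batch_x_heads, k_tokens, channels)

out = efficient_dot_product_attention(
    query,
    key.transpose(1, 2),   # callers now pre-transpose the key once
    value,
    query_chunk_size=1024,
    kv_chunk_size=None,    # defaults to roughly sqrt(k_tokens), as above
)
assert out.shape == (batch_x_heads, q_tokens, channels)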
