Skip to content

Commit 354d626

Browse files
committed
Add key pre-transpose to sub-quadratic attention
1 parent 848605f commit 354d626

File tree

2 files changed: +27 −12 lines changed

modules/sd_hijack_optimizations.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,11 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
         del context, context_k, context_v, x

         q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
-        k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+        k_t = k.transpose(1,2).unflatten(1, (h, -1)).flatten(end_dim=1)
+        del k
         v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)

-        x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+        x = sub_quad_attention(q, k_t, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training, key_needs_transpose=False)

         x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)

@@ -243,7 +244,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):

         return x

-def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
+def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True, key_needs_transpose=True):
     bytes_per_token = torch.finfo(q.dtype).bits//8
     batch_x_heads, q_tokens, _ = q.shape
     _, k_tokens, _ = k.shape
@@ -275,6 +276,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
         kv_chunk_size=kv_chunk_size,
         kv_chunk_size_min = kv_chunk_size_min,
         use_checkpoint=use_checkpoint,
+        key_needs_transpose=key_needs_transpose,
     )

modules/sub_quadratic_attention.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,12 @@ def _summarize_chunk(
    key: Tensor,
    value: Tensor,
    scale: float,
+    key_needs_transpose: bool,
) -> AttnChunk:
    attn_weights = torch.baddbmm(
        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
        query,
-        key.transpose(1,2),
+        key.transpose(1,2) if key_needs_transpose else key,
        alpha=scale,
        beta=0,
    )
@@ -72,14 +73,18 @@ def _query_chunk_attention(
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
+    key_needs_transpose: bool,
) -> Tensor:
-    batch_x_heads, k_tokens, k_channels_per_head = key.shape
+    if key_needs_transpose:
+        batch_x_heads, k_tokens, k_channels_per_head = key.shape
+    else:
+        batch_x_heads, k_channels_per_head, k_tokens = key.shape
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int) -> AttnChunk:
        key_chunk = narrow_trunc(
            key,
-            1,
+            1 if key_needs_transpose else 2,
            chunk_idx,
            kv_chunk_size
        )
@@ -112,11 +117,12 @@ def _get_attention_scores_no_kv_chunking(
    key: Tensor,
    value: Tensor,
    scale: float,
+    key_needs_transpose: bool,
) -> Tensor:
    attn_scores = torch.baddbmm(
        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
        query,
-        key.transpose(1,2),
+        key.transpose(1,2) if key_needs_transpose else key,
        alpha=scale,
        beta=0,
    )
@@ -136,7 +142,8 @@ def efficient_dot_product_attention(
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
-    use_checkpoint=True,
+    use_checkpoint: Optional[bool] = True,
+    key_needs_transpose: Optional[bool] = True,
):
    """Computes efficient dot-product attention given query, key, and value.
    This is efficient version of attention presented in
@@ -151,12 +158,16 @@ def efficient_dot_product_attention(
      query_chunk_size: int: query chunks size
      kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
      kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
-      use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
+      use_checkpoint: Optional[bool]: whether to use checkpointing (recommended True for training, False for inference)
+      key_needs_transpose: Optional[bool]: whether key needs a transpose. defaults to True
    Returns:
      Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
    """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
-    _, k_tokens, _ = key.shape
+    if key_needs_transpose:
+        _, k_tokens, _ = key.shape
+    else:
+        _, _, k_tokens = key.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
@@ -171,17 +182,19 @@ def get_query_chunk(chunk_idx: int) -> Tensor:
            min(query_chunk_size, q_tokens)
        )

-    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
+    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, key_needs_transpose=key_needs_transpose)
    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
        _get_attention_scores_no_kv_chunking,
-        scale=scale
+        scale=scale,
+        key_needs_transpose=key_needs_transpose,
    ) if k_tokens <= kv_chunk_size else (
        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
        partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
+            key_needs_transpose=key_needs_transpose,
        )
    )

0 commit comments

Comments
 (0)