Commit 7f4eb85

Add Birch-san's sub-quadratic attention implementation
1 parent 4af3ca5 commit 7f4eb85

5 files changed: +194 −0 lines

README.md (+1)

@@ -139,6 +139,7 @@ The documentation was moved from this README over to the project's [wiki](https:
 - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
 - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
 - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
+- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san), Amin Rezaei (https://github.com/AminRezaei0x443)
 - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
 - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
 - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot

modules/sd_hijack.py (+3)

@@ -40,6 +40,9 @@ def apply_optimizations():
         print("Applying xformers cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+    elif cmd_opts.opt_sub_quad_attention:
+        print("Applying sub-quadratic cross attention optimization.")
+        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
     elif cmd_opts.opt_split_attention_v1:
         print("Applying v1 cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
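
Note: the hijack works by reassigning the unbound forward method on the CrossAttention class, so every existing and future instance dispatches to the replacement. A minimal sketch of the same pattern with stand-in classes (illustrative only, not the real ldm modules):

import torch.nn as nn

class CrossAttention(nn.Module):              # stand-in for ldm.modules.attention.CrossAttention
    def forward(self, x, context=None, mask=None):
        return x                              # placeholder for the original quadratic attention

def sub_quad_attention_forward(self, x, context=None, mask=None):
    return x                                  # placeholder for the sub-quadratic implementation

# class-level assignment: the same trick apply_optimizations() uses above
CrossAttention.forward = sub_quad_attention_forward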

modules/sd_hijack_optimizations.py (+40)

@@ -12,6 +12,8 @@
 from modules import shared
 from modules.hypernetworks import hypernetwork
 
+from .sub_quadratic_attention import efficient_dot_product_attention
+
 
 if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
     try:
@@ -215,6 +217,44 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
 
 # -- End of code from https://github.com/invoke-ai/InvokeAI --
 
+
+# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
+def sub_quad_attention_forward(self, x, context=None, mask=None):
+    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
+
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+
+    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    k = self.to_k(context_k)
+    v = self.to_v(context_v)
+    del context, context_k, context_v, x
+
+    q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+    k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+    v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+
+    dtype = q.dtype
+    # TODO: do we still need to do *everything* in float32, given how we delay the division?
+    # TODO: do we need to support upcast_softmax too? SD 2.1 seems to work without it
+
+    x = efficient_dot_product_attention(
+        q,
+        k,
+        v,
+    )
+    x = x.to(dtype)
+
+    x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
+
+    out_proj, dropout = self.to_out
+    x = out_proj(x)
+    x = dropout(x)
+
+    return x
+
 def xformers_attention_forward(self, x, context=None, mask=None):
     h = self.heads
     q_in = self.to_q(x)
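
Note: the unflatten/transpose/flatten chains above merge the head dimension into the batch dimension before calling efficient_dot_product_attention, then undo it afterwards. A minimal sketch with toy shapes (the numbers are illustrative, not from the commit):

import torch

batch, tokens, heads, dim_head = 2, 4096, 8, 40
q = torch.randn(batch, tokens, heads * dim_head)      # shape of self.to_q(x)

# [batch, tokens, heads * dim_head] -> [batch * heads, tokens, dim_head]
q_merged = q.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)
assert q_merged.shape == (batch * heads, tokens, dim_head)

# after attention: [batch * heads, tokens, dim_head] -> [batch, tokens, heads * dim_head]
out = q_merged.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)
assert out.shape == (batch, tokens, heads * dim_head)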

modules/shared.py (+1)

@@ -56,6 +56,7 @@
 parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
 parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
 parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
+parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
 parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
 parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")

modules/sub_quadratic_attention.py (+149, new file)

@@ -0,0 +1,149 @@
+# original source:
+#   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
+# license:
+#   unspecified
+# credit:
+#   Amin Rezaei (original author)
+#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
+# implementation of:
+#   "Self-attention Does Not Need O(n²) Memory":
+#   https://arxiv.org/abs/2112.05682v2
+
+from functools import partial
+import torch
+from torch import Tensor
+from torch.utils.checkpoint import checkpoint
+import math
+from typing import Optional, NamedTuple, Protocol, List
+
+def dynamic_slice(
+    x: Tensor,
+    starts: List[int],
+    sizes: List[int],
+) -> Tensor:
+    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
+    return x[slicing]
+
+class AttnChunk(NamedTuple):
+    exp_values: Tensor
+    exp_weights_sum: Tensor
+    max_score: Tensor
+
+class SummarizeChunk(Protocol):
+    @staticmethod
+    def __call__(
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+    ) -> AttnChunk: ...
+
+def _query_chunk_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    key_chunk_size: Optional[int] = None,
+    use_checkpoint = True,
+):
+    batch_x_heads, k_tokens, k_channels_per_head = key.shape
+    _, _, v_channels_per_head = value.shape
+    key_chunk_size = min(key_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
+    scale = k_channels_per_head ** -0.5
+
+    def summarize_chunk(
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+    ) -> AttnChunk:
+        attn_weights = torch.baddbmm(
+            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+            query,
+            key.transpose(1,2),
+            alpha=scale,
+            beta=0,
+        )
+        max_score, _ = torch.max(attn_weights, -1, keepdim=True)
+        max_score = max_score.detach()
+        exp_weights = torch.exp(attn_weights - max_score)
+        exp_values = torch.bmm(exp_weights, value)
+        max_score = max_score.squeeze(-1)
+        return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
+    summarizer: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
+
+    def chunk_scanner(chunk_idx: int) -> AttnChunk:
+        key_chunk = dynamic_slice(
+            key,
+            (0, chunk_idx, 0),
+            (batch_x_heads, key_chunk_size, k_channels_per_head)
+        )
+        value_chunk = dynamic_slice(
+            value,
+            (0, chunk_idx, 0),
+            (batch_x_heads, key_chunk_size, v_channels_per_head)
+        )
+
+        return summarizer(query, key_chunk, value_chunk)
+
+    chunks: List[AttnChunk] = [
+        chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, key_chunk_size)
+    ]
+    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
+    chunk_values, chunk_weights, chunk_max = acc_chunk
+
+    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
+    max_diffs = torch.exp(chunk_max - global_max)
+    chunk_values *= torch.unsqueeze(max_diffs, -1)
+    chunk_weights *= max_diffs
+
+    all_values = chunk_values.sum(dim=0)
+    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
+    return all_values / all_weights
+
+class ScannedChunk(NamedTuple):
+    chunk_idx: int
+    attn_chunk: AttnChunk
+
+def efficient_dot_product_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    query_chunk_size=1024,
+    key_chunk_size: Optional[int] = None,
+    use_checkpoint=True,
+):
+    """Computes efficient dot-product attention given query, key, and value.
+      This is an efficient version of attention presented in
+      https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
+      Args:
+        query: queries for calculating attention with shape of
+          `[batch * num_heads, tokens, channels_per_head]`.
+        key: keys for calculating attention with shape of
+          `[batch * num_heads, tokens, channels_per_head]`.
+        value: values to be used in attention with shape of
+          `[batch * num_heads, tokens, channels_per_head]`.
+        query_chunk_size: int: query chunks size
+        key_chunk_size: int: key chunks size
+        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
+      Returns:
+        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
+      """
+    batch_x_heads, q_tokens, q_channels_per_head = query.shape
+
+    def chunk_scanner(chunk_idx: int) -> Tensor:
+        query_chunk = dynamic_slice(
+            query,
+            (0, chunk_idx, 0),
+            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
+        )
+
+        return _query_chunk_attention(
+            query_chunk,
+            key,
+            value,
+            key_chunk_size=key_chunk_size,
+            use_checkpoint=use_checkpoint,
+        )
+
+    res = torch.cat([
+        chunk_scanner(i * query_chunk_size) for i in range(math.ceil(q_tokens / query_chunk_size))
+    ], dim=1)
+    return res
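
Note: the new module streams over key chunks, keeping a running maximum per query so each chunk's softmax numerators and denominators can be rescaled and summed without ever materializing the full attention matrix; query chunks are processed independently and concatenated. A hedged sanity check (not part of the commit; assumes the repository root is on sys.path so the file imports as modules.sub_quadratic_attention) comparing the chunked result against naive softmax attention on small random tensors:

import torch
from modules.sub_quadratic_attention import efficient_dot_product_attention

batch_x_heads, q_tokens, k_tokens, channels = 4, 77, 77, 40
q = torch.randn(batch_x_heads, q_tokens, channels)
k = torch.randn(batch_x_heads, k_tokens, channels)
v = torch.randn(batch_x_heads, k_tokens, channels)

# reference: materialize the full [batch_x_heads, q_tokens, k_tokens] score matrix
scores = torch.bmm(q, k.transpose(1, 2)) * channels ** -0.5
reference = torch.bmm(scores.softmax(dim=-1), v)

# chunked: small chunk sizes to force several query and key chunks
chunked = efficient_dot_product_attention(
    q, k, v,
    query_chunk_size=32,
    key_chunk_size=16,
    use_checkpoint=False,  # checkpointing only helps when training
)
print(torch.allclose(reference, chunked, atol=1e-5))  # expected: True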
