Commit c295e4a

Merge pull request #6055 from brkirch/sub-quad_attn_opt
Add Birch-san's sub-quadratic attention implementation
2 parents 1a5b86a + c18add6 commit c295e4a

7 files changed: +348 −39 lines

README.md

+1

@@ -141,6 +141,7 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
 - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
 - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
+- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention)
 - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
 - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
 - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot

html/licenses.html

+28 −1

@@ -184,7 +184,7 @@ <h2><a href="https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICE
 </pre>
 
 <h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
-<small>Code added by contirubtors, most likely copied from this repository.</small>
+<small>Code added by contributors, most likely copied from this repository.</small>
 
 <pre>
 Apache License
@@ -390,3 +390,30 @@ <h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a
 limitations under the License.
 </pre>
 
+<h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
+<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
+<pre>
+MIT License
+
+Copyright (c) 2023 Alex Birch
+Copyright (c) 2023 Amin Rezaei
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+</pre>
+

modules/sd_hijack.py

+9 −12

@@ -7,8 +7,6 @@
 from modules.shared import cmd_opts
 from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
 
-from modules.sd_hijack_optimizations import invokeAI_mps_available
-
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
 import ldm.modules.diffusionmodules.openaimodel
@@ -43,20 +41,19 @@ def apply_optimizations():
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
         optimization_method = 'xformers'
+    elif cmd_opts.opt_sub_quad_attention:
+        print("Applying sub-quadratic cross attention optimization.")
+        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
+        optimization_method = 'sub-quadratic'
     elif cmd_opts.opt_split_attention_v1:
         print("Applying v1 cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
         optimization_method = 'V1'
-    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
-        if not invokeAI_mps_available and shared.device.type == 'mps':
-            print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
-            print("Applying v1 cross attention optimization.")
-            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
-            optimization_method = 'V1'
-        else:
-            print("Applying cross attention optimization (InvokeAI).")
-            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
-            optimization_method = 'InvokeAI'
+    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
+        print("Applying cross attention optimization (InvokeAI).")
+        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+        optimization_method = 'InvokeAI'
     elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
         print("Applying cross attention optimization (Doggettx).")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward

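For context on the hunk above: apply_optimizations() picks an implementation by testing the flags in a fixed order, and the reworked InvokeAI condition no longer fires when --opt-split-attention is passed explicitly on a CUDA-less machine. A minimal sketch of that precedence, with cmd_opts reduced to a plain dict and the CUDA-capability gate on xformers omitted (pick_optimization and its arguments are names invented for this example, not repo code):

# Illustrative only: the real logic lives in apply_optimizations() above.
def pick_optimization(opts, cuda_available=True, xformers_ok=False):
    if opts.get("force_enable_xformers") or (opts.get("xformers") and xformers_ok):
        return "xformers"
    if opts.get("opt_sub_quad_attention"):
        return "sub-quadratic"
    if opts.get("opt_split_attention_v1"):
        return "V1"
    if not opts.get("disable_opt_split_attention") and (
        opts.get("opt_split_attention_invokeai")
        or (not opts.get("opt_split_attention") and not cuda_available)
    ):
        return "InvokeAI"
    if not opts.get("disable_opt_split_attention") and (opts.get("opt_split_attention") or cuda_available):
        return "Doggettx"
    return None

# The new flag wins over the split-attention fallbacks but not over xformers:
assert pick_optimization({"opt_sub_quad_attention": True}) == "sub-quadratic"
# Without CUDA, explicitly forcing --opt-split-attention now skips the InvokeAI default:
assert pick_optimization({"opt_split_attention": True}, cuda_available=False) == "Doggettx"
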
modules/sd_hijack_optimizations.py

+100 −25

@@ -1,7 +1,7 @@
 import math
 import sys
 import traceback
-import importlib
+import psutil
 
 import torch
 from torch import einsum
@@ -12,6 +12,8 @@
 from modules import shared
 from modules.hypernetworks import hypernetwork
 
+from .sub_quadratic_attention import efficient_dot_product_attention
+
 
 if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
     try:
@@ -22,6 +24,19 @@
         print(traceback.format_exc(), file=sys.stderr)
 
 
+def get_available_vram():
+    if shared.device.type == 'cuda':
+        stats = torch.cuda.memory_stats(shared.device)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+        return mem_free_total
+    else:
+        return psutil.virtual_memory().available
+
+
 # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
 def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     h = self.heads
@@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
 
     r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
 
-    stats = torch.cuda.memory_stats(q.device)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-    mem_free_torch = mem_reserved - mem_active
-    mem_free_total = mem_free_cuda + mem_free_torch
+    mem_free_total = get_available_vram()
 
     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
@@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     return self.to_out(r2)
 
 
-def check_for_psutil():
-    try:
-        spec = importlib.util.find_spec('psutil')
-        return spec is not None
-    except ModuleNotFoundError:
-        return False
-
-invokeAI_mps_available = check_for_psutil()
-
 # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
-if invokeAI_mps_available:
-    import psutil
-    mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+mem_total_gb = psutil.virtual_memory().total // (1 << 30)
 
 def einsum_op_compvis(q, k, v):
     s = einsum('b i d, b j d -> b i j', q, k)
@@ -215,6 +214,71 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
 
 # -- End of code from https://github.com/invoke-ai/InvokeAI --
 
+
+# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
+# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
+def sub_quad_attention_forward(self, x, context=None, mask=None):
+    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
+
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+
+    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    k = self.to_k(context_k)
+    v = self.to_v(context_v)
+    del context, context_k, context_v, x
+
+    q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+    k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+    v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+
+    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+
+    x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
+
+    out_proj, dropout = self.to_out
+    x = out_proj(x)
+    x = dropout(x)
+
+    return x
+
+def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
+    bytes_per_token = torch.finfo(q.dtype).bits//8
+    batch_x_heads, q_tokens, _ = q.shape
+    _, k_tokens, _ = k.shape
+    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
+
+    if chunk_threshold is None:
+        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+    elif chunk_threshold == 0:
+        chunk_threshold_bytes = None
+    else:
+        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
+
+    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
+        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
+    elif kv_chunk_size_min == 0:
+        kv_chunk_size_min = None
+
+    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
+        # the big matmul fits into our memory limit; do everything in 1 chunk,
+        # i.e. send it down the unchunked fast-path
+        query_chunk_size = q_tokens
+        kv_chunk_size = k_tokens
+
+    return efficient_dot_product_attention(
+        q,
+        k,
+        v,
+        query_chunk_size=q_chunk_size,
+        kv_chunk_size=kv_chunk_size,
+        kv_chunk_size_min = kv_chunk_size_min,
+        use_checkpoint=use_checkpoint,
+    )
+
+
 def xformers_attention_forward(self, x, context=None, mask=None):
     h = self.heads
     q_in = self.to_q(x)
@@ -252,12 +316,7 @@ def cross_attention_attnblock_forward(self, x):
 
         h_ = torch.zeros_like(k, device=q.device)
 
-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total = get_available_vram()
 
         tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
         mem_required = tensor_size * 2.5
@@ -312,3 +371,19 @@ def xformers_attnblock_forward(self, x):
         return x + out
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
+
+def sub_quad_attnblock_forward(self, x):
+    h_ = x
+    h_ = self.norm(h_)
+    q = self.q(h_)
+    k = self.k(h_)
+    v = self.v(h_)
+    b, c, h, w = q.shape
+    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+    q = q.contiguous()
+    k = k.contiguous()
+    v = v.contiguous()
+    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+    out = self.proj_out(out)
+    return x + out

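The new forwards above delegate the actual attention to efficient_dot_product_attention, which lives in modules/sub_quadratic_attention.py (added by this commit but not shown in this excerpt); sub_quad_attention itself only decides whether chunking is needed by comparing the estimated q·kᵀ matmul size against a fraction of the free memory reported by get_available_vram(). The underlying idea is memory-efficient attention in the style of Rabe & Staats: process queries in chunks and, per query chunk, stream over key/value chunks with a running max, softmax denominator, and weighted value sum, so the full q_tokens × k_tokens score matrix never exists at once. A simplified, self-contained sketch of that idea (illustrative only, not the code this commit adds; chunked_attention and its defaults are invented for the example):

import torch

def chunked_attention(q, k, v, query_chunk_size=1024, kv_chunk_size=4096):
    # q, k, v: (batch * heads, tokens, head_dim) -- the layout sub_quad_attention_forward
    # produces before handing off to the real implementation.
    scale = q.shape[-1] ** -0.5
    out = torch.empty(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for qs in range(0, q.shape[1], query_chunk_size):
        q_chunk = q[:, qs:qs + query_chunk_size] * scale
        # running statistics for a numerically stable streaming softmax over kv chunks
        acc = torch.zeros(q_chunk.shape[0], q_chunk.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        denom = torch.zeros(q_chunk.shape[0], q_chunk.shape[1], 1, device=q.device, dtype=q.dtype)
        running_max = torch.full((q_chunk.shape[0], q_chunk.shape[1], 1), float("-inf"), device=q.device, dtype=q.dtype)
        for ks in range(0, k.shape[1], kv_chunk_size):
            k_chunk = k[:, ks:ks + kv_chunk_size]
            v_chunk = v[:, ks:ks + kv_chunk_size]
            scores = q_chunk @ k_chunk.transpose(-1, -2)   # only (q chunk x kv chunk) is materialized
            chunk_max = scores.amax(dim=-1, keepdim=True)
            new_max = torch.maximum(running_max, chunk_max)
            correction = torch.exp(running_max - new_max)  # rescale what was accumulated so far
            weights = torch.exp(scores - new_max)
            acc = acc * correction + weights @ v_chunk
            denom = denom * correction + weights.sum(dim=-1, keepdim=True)
            running_max = new_max
        out[:, qs:qs + query_chunk_size] = acc / denom
    return out

# Rough numerical check against plain softmax attention:
q = torch.randn(2, 4096, 40); k = torch.randn(2, 77, 40); v = torch.randn(2, 77, 40)
reference = torch.softmax(q @ k.transpose(-1, -2) * 40 ** -0.5, dim=-1) @ v
assert torch.allclose(chunked_attention(q, k, v, query_chunk_size=512, kv_chunk_size=32), reference, atol=1e-4)
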
modules/shared.py

+4

@@ -56,6 +56,10 @@
 parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
 parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
 parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
+parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
+parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
+parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
+parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
 parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
 parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")

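These are ordinary argparse options in modules/shared.py, so they can be passed on the command line or through COMMANDLINE_ARGS in the webui-user launcher script. A hypothetical invocation, assuming the standard launch.py entry point:

python launch.py --opt-sub-quad-attention --sub-quad-q-chunk-size 512 --sub-quad-chunk-threshold 80

With --sub-quad-chunk-threshold 80, sub_quad_attention plans chunking against roughly 80% of the free memory reported by get_available_vram(); leaving it unset uses 70% of free memory on CUDA and 90% on MPS, and passing 0 disables the memory-limit check entirely.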