Commit cadcb36

Adjust memory usage (and various other changes)
Adjust memory usage, add command line options, make sub-quadratic attention the default when CUDA is unavailable, and change the sub-quadratic AttnBlock forward to use the same implementation the web UI uses for xformers.
1 parent 4bfa22e commit cadcb36
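
For reference, the new sub-quadratic options added in modules/shared.py (below) can be combined on the launch command line. A minimal usage sketch, assuming the usual webui.py entry point; the chunk sizes and the byte threshold are illustrative values, not defaults set by this commit:

python webui.py --opt-sub-quad-attention --sub-quad-q-chunk-size 512 --sub-quad-kv-chunk-size 512 --sub-quad-chunk-threshold 268435456

Leaving the chunk options unset keeps the defaults from modules/shared.py, in which case sub_quad_attention derives its chunking threshold from available memory (see the sd_hijack_optimizations.py changes below).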

File tree: 3 files changed (+40 -58 lines)

modules/sd_hijack.py (+5 -5)
@@ -38,16 +38,16 @@ def apply_optimizations():
         print("Applying xformers cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
-    elif cmd_opts.opt_sub_quad_attention:
-        print("Applying sub-quadratic cross attention optimization.")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
-        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
     elif cmd_opts.opt_split_attention_v1:
         print("Applying v1 cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
-    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
+    elif cmd_opts.opt_split_attention_invokeai:
         print("Applying cross attention optimization (InvokeAI).")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_sub_quad_attention or not torch.cuda.is_available()):
+        print("Applying sub-quadratic cross attention optimization.")
+        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
     elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
         print("Applying cross attention optimization (Doggettx).")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
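
Restated as a sketch, the optimization precedence after this change; this condenses the apply_optimizations() branches above, with xformers_ok and cuda_ok as hypothetical stand-ins for the web UI's xformers check and torch.cuda.is_available():

def pick_optimization(cmd_opts, xformers_ok, cuda_ok):
    # Condensed restatement of the if/elif chain above; returns a label only.
    if xformers_ok:
        return "xformers"
    if cmd_opts.opt_split_attention_v1:
        return "v1"
    if cmd_opts.opt_split_attention_invokeai:
        return "InvokeAI"
    if not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_sub_quad_attention or not cuda_ok):
        return "sub-quadratic"
    if not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or cuda_ok):
        return "Doggettx"
    return None

The practical effect is that --opt-split-attention-invokeai now applies only when explicitly requested, and the sub-quadratic path becomes the default whenever CUDA is unavailable.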

modules/sd_hijack_optimizations.py (+30 -51)
@@ -24,6 +24,19 @@
         print(traceback.format_exc(), file=sys.stderr)
 
 
+def get_available_vram():
+    if shared.device.type == 'cuda':
+        stats = torch.cuda.memory_stats(shared.device)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+        return mem_free_total
+    else:
+        return psutil.virtual_memory().available
+
+
 # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
 def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     h = self.heads
@@ -78,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
 
     r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
 
-    stats = torch.cuda.memory_stats(q.device)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-    mem_free_torch = mem_reserved - mem_active
-    mem_free_total = mem_free_cuda + mem_free_torch
+    mem_free_total = get_available_vram()
 
     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
@@ -207,18 +215,6 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
 # -- End of code from https://github.com/invoke-ai/InvokeAI --
 
 
-def get_available_vram():
-    if shared.device.type == 'cuda':
-        stats = torch.cuda.memory_stats(shared.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
-        return mem_free_total
-    else:
-        return psutil.virtual_memory().available
-
 # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
 def sub_quad_attention_forward(self, x, context=None, mask=None):
     assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
@@ -237,7 +233,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
     k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
     v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
 
-    x = sub_quad_attention(q, k, v, kv_chunk_size_min=None, chunk_threshold_bytes=(get_available_vram() if q.device.type == 'mps' else int(get_available_vram() * 0.95)), use_checkpoint=self.training)
+    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
 
     x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
 
@@ -253,7 +249,12 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
     _, k_tokens, _ = k.shape
     qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
-    if kv_chunk_size_min == None:
+    if chunk_threshold_bytes is None:
+        chunk_threshold_bytes = int(get_available_vram() * 0.4)
+    elif chunk_threshold_bytes == 0:
+        chunk_threshold_bytes = None
+
+    if kv_chunk_size_min is None:
         kv_chunk_size_min = (chunk_threshold_bytes - batch_x_heads * min(q_chunk_size, q_tokens) * q.shape[2]) // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
     elif kv_chunk_size_min == 0:
         kv_chunk_size_min = None
@@ -312,12 +313,7 @@ def cross_attention_attnblock_forward(self, x):
 
     h_ = torch.zeros_like(k, device=q.device)
 
-    stats = torch.cuda.memory_stats(q.device)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-    mem_free_torch = mem_reserved - mem_active
-    mem_free_total = mem_free_cuda + mem_free_torch
+    mem_free_total = get_available_vram()
 
     tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
     mem_required = tensor_size * 2.5
@@ -373,35 +369,18 @@ def xformers_attnblock_forward(self, x):
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
 
-# MemoryEfficientAttnBlock forward from https://github.com/Stability-AI/stablediffusion modified to use sub-quadratic attention instead of xformers
 def sub_quad_attnblock_forward(self, x):
     h_ = x
     h_ = self.norm(h_)
     q = self.q(h_)
     k = self.k(h_)
     v = self.v(h_)
-
-    # compute attention
-    B, C, H, W = q.shape
-    q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
-
-    q, k, v = map(
-        lambda t: t.unsqueeze(3)
-        .reshape(B, t.shape[1], 1, C)
-        .permute(0, 2, 1, 3)
-        .reshape(B * 1, t.shape[1], C)
-        .contiguous(),
-        (q, k, v),
-    )
-
-    out = sub_quad_attention(q, k, v, kv_chunk_size_min=0, chunk_threshold_bytes=(get_available_vram() if q.device.type == 'mps' else int(get_available_vram() * 0.95)), use_checkpoint=self.training)
-
-    out = (
-        out.unsqueeze(0)
-        .reshape(B, 1, out.shape[1], C)
-        .permute(0, 2, 1, 3)
-        .reshape(B, out.shape[1], C)
-    )
-    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+    b, c, h, w = q.shape
+    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+    q = q.contiguous()
+    k = k.contiguous()
+    v = v.contiguous()
+    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
     out = self.proj_out(out)
-    return x+out
+    return x + out

modules/shared.py (+5 -2)
@@ -56,8 +56,11 @@
 parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
 parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
 parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
-parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
-parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
+parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization. By default, it's on when cuda is unavailable.")
+parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
+parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
+parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the size threshold in bytes for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
+parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization")
 parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
 parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
