@@ -1,7 +1,7 @@
 import math
 import sys
 import traceback
-import importlib
+import psutil

 import torch
 from torch import einsum
@@ -12,6 +12,8 @@
 from modules import shared
 from modules.hypernetworks import hypernetwork

+from .sub_quadratic_attention import efficient_dot_product_attention
+

 if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
     try:
@@ -22,6 +22,19 @@
         print(traceback.format_exc(), file=sys.stderr)


+def get_available_vram():
+    if shared.device.type == 'cuda':
+        stats = torch.cuda.memory_stats(shared.device)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+        return mem_free_total
+    else:
+        return psutil.virtual_memory().available
+
+
 # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
 def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     h = self.heads
@@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):

     r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

-    stats = torch.cuda.memory_stats(q.device)
-    mem_active = stats['active_bytes.all.current']
-    mem_reserved = stats['reserved_bytes.all.current']
-    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-    mem_free_torch = mem_reserved - mem_active
-    mem_free_total = mem_free_cuda + mem_free_torch
+    mem_free_total = get_available_vram()

     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
@@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     return self.to_out(r2)


-def check_for_psutil():
-    try:
-        spec = importlib.util.find_spec('psutil')
-        return spec is not None
-    except ModuleNotFoundError:
-        return False
-
-invokeAI_mps_available = check_for_psutil()
-
 # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
-if invokeAI_mps_available:
-    import psutil
-    mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+mem_total_gb = psutil.virtual_memory().total // (1 << 30)

 def einsum_op_compvis(q, k, v):
     s = einsum('b i d, b j d -> b i j', q, k)
@@ -215,6 +214,71 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):

 # -- End of code from https://github.com/invoke-ai/InvokeAI --

+
+# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
+# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
+def sub_quad_attention_forward(self, x, context=None, mask=None):
+    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
+
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+
+    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    k = self.to_k(context_k)
+    v = self.to_v(context_v)
+    del context, context_k, context_v, x
+
+    q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+    k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+    v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+
+    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+
+    x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
+
+    out_proj, dropout = self.to_out
+    x = out_proj(x)
+    x = dropout(x)
+
+    return x
+
+def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
+    bytes_per_token = torch.finfo(q.dtype).bits//8
+    batch_x_heads, q_tokens, _ = q.shape
+    _, k_tokens, _ = k.shape
+    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
+
+    if chunk_threshold is None:
+        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+    elif chunk_threshold == 0:
+        chunk_threshold_bytes = None
+    else:
+        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
+
+    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
+        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
+    elif kv_chunk_size_min == 0:
+        kv_chunk_size_min = None
+
+    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
+        # the big matmul fits into our memory limit; do everything in 1 chunk,
+        # i.e. send it down the unchunked fast-path
+        q_chunk_size = q_tokens
+        kv_chunk_size = k_tokens
+
+    return efficient_dot_product_attention(
+        q,
+        k,
+        v,
+        query_chunk_size=q_chunk_size,
+        kv_chunk_size=kv_chunk_size,
+        kv_chunk_size_min = kv_chunk_size_min,
+        use_checkpoint=use_checkpoint,
+    )
+
+
 def xformers_attention_forward(self, x, context=None, mask=None):
     h = self.heads
     q_in = self.to_q(x)
@@ -252,12 +316,7 @@ def cross_attention_attnblock_forward(self, x):

         h_ = torch.zeros_like(k, device=q.device)

-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats['active_bytes.all.current']
-        mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total = get_available_vram()

         tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
         mem_required = tensor_size * 2.5
@@ -312,3 +371,19 @@ def xformers_attnblock_forward(self, x):
         return x + out
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
+
+def sub_quad_attnblock_forward(self, x):
+    h_ = x
+    h_ = self.norm(h_)
+    q = self.q(h_)
+    k = self.k(h_)
+    v = self.v(h_)
+    b, c, h, w = q.shape
+    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+    q = q.contiguous()
+    k = k.contiguous()
+    v = v.contiguous()
+    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+    out = self.proj_out(out)
+    return x + out
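
Note on the chunking decision in sub_quad_attention: the patch budgets the full q.kT score matrix against a fraction of the memory reported by get_available_vram() (90% on MPS, 70% otherwise, or a user-supplied percentage via chunk_threshold) and only takes the chunked sub-quadratic path when that matrix would not fit. Below is a minimal, self-contained sketch of that arithmetic; the shapes and the 4 GiB free-memory figure are made-up illustrative values standing in for real tensors and for get_available_vram().

import torch

# assumed example values: 2 images x 8 heads, 4096 latent tokens, fp16 activations
batch_x_heads, q_tokens, k_tokens = 16, 4096, 4096
dtype = torch.float16

bytes_per_token = torch.finfo(dtype).bits // 8
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

mem_free_total = 4 * 1024 ** 3                      # stand-in for get_available_vram(): pretend 4 GiB are free
chunk_threshold_bytes = int(mem_free_total * 0.7)   # default non-MPS fraction used in the patch

if qk_matmul_size_bytes <= chunk_threshold_bytes:
    print("score matrix fits the budget: unchunked fast path (one chunk per axis)")
else:
    print("score matrix too large: chunked sub-quadratic path")

With these assumed numbers the half-precision 16 x 4096 x 4096 score matrix is roughly 0.5 GiB, well under the 2.8 GiB threshold, so both chunk sizes are set to the full token counts and the call reduces to ordinary single-chunk attention.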
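
sub_quad_attention expects q, k and v shaped (batch * heads, tokens, dim_head); sub_quad_attention_forward folds the heads into the batch dimension before the call and unfolds them again afterwards. A small sketch of that round trip, with assumed example sizes (2 items, 77 tokens, 8 heads of 64 dims), purely to show the layout:

import torch

batch, tokens, heads, dim_head = 2, 77, 8, 64     # assumed example sizes
x = torch.randn(batch, tokens, heads * dim_head)  # shape produced by to_q/to_k/to_v

# fold heads into the batch dimension, as done before calling sub_quad_attention
folded = x.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)
assert folded.shape == (batch * heads, tokens, dim_head)

# unfold back to (batch, tokens, heads * dim_head), as done before to_out
restored = folded.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)
assert restored.shape == x.shape
assert torch.equal(restored, x)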