Commit 957891c

1 parent fa9bc02

2 files changed: +290 -2 lines changed

src/diffusers/models/attention_processor.py (+93 -2)
@@ -24,11 +24,10 @@
 from ..utils.import_utils import is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .lora import LoRACompatibleLinear, LoRALinearLayer
-
+from .sub_quadratic_attention import efficient_dot_product_attention as attn_subquad

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

-
 if is_xformers_available():
     import xformers
     import xformers.ops
@@ -2408,6 +2407,98 @@ def __call__(
         return hidden_states


+class SubQuadraticCrossAttnProcessor:
+    query_chunk_size: int
+    kv_chunk_size: Optional[int]
+    kv_chunk_size_min: Optional[int]
+    chunk_threshold_bytes: Optional[int]
+
+    def __init__(
+        self,
+        query_chunk_size=1024,
+        kv_chunk_size: Optional[int] = None,
+        kv_chunk_size_min: Optional[int] = None,
+        chunk_threshold_bytes: Optional[int] = None,
+    ):
+        r"""
+        Args:
+            query_chunk_size (`int`, *optional*, defaults to `1024`)
+            kv_chunk_size (`int`, *optional*, defaults to `None`): if None, sqrt(key_tokens) is used.
+            kv_chunk_size_min (`int`, *optional*, defaults to `None`): only considered when `kv_chunk_size is None`. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
+            chunk_threshold_bytes (`int`, *optional*, defaults to `None`): if defined: only bother chunking if the self-attn matmul would allocate more bytes than this. whenever we can fit traditional attention into memory: we should prefer to do so, as the unchunked algorithm is faster.
+        """
+        self.query_chunk_size = query_chunk_size
+        self.kv_chunk_size = kv_chunk_size
+        self.kv_chunk_size_min = kv_chunk_size_min
+        self.chunk_threshold_bytes = chunk_threshold_bytes
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
+
+        assert attention_mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
+        # I don't know what test case can be used to determine whether softmax is computed at sufficient bit-width,
+        # but sub-quadratic attention has a pretty bespoke softmax (defers computation of the denominator) so this needs some thought.
+        assert (
+            not attn.upcast_softmax or torch.finfo(hidden_states.dtype).bits >= 32
+        ), "upcast_softmax was requested, but is not implemented"
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = query.unflatten(-1, (attn.heads, -1)).transpose(1, 2).flatten(end_dim=1)
+        key_t = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).flatten(end_dim=1)
+        del key
+        value = value.unflatten(-1, (attn.heads, -1)).transpose(1, 2).flatten(end_dim=1)
+
+        dtype = query.dtype
+        # TODO: do we still need to do *everything* in float32, given how we delay the division?
+        # TODO: do we need to support upcast_softmax too? SD 2.1 seems to work without it
+        if attn.upcast_attention:
+            query = query.float()
+            key_t = key_t.float()
+
+        bytes_per_token = torch.finfo(query.dtype).bits // 8
+        batch_x_heads, q_tokens, _ = query.shape
+        _, _, k_tokens = key_t.shape
+        qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
+
+        query_chunk_size = self.query_chunk_size
+        kv_chunk_size = self.kv_chunk_size
+
+        if self.chunk_threshold_bytes is not None and qk_matmul_size_bytes <= self.chunk_threshold_bytes:
+            # the big matmul fits into our memory limit; do everything in 1 chunk,
+            # i.e. send it down the unchunked fast-path
+            query_chunk_size = q_tokens
+            kv_chunk_size = k_tokens
+
+        hidden_states = attn_subquad(
+            query,
+            key_t,
+            value,
+            query_chunk_size=query_chunk_size,
+            kv_chunk_size=kv_chunk_size,
+            kv_chunk_size_min=self.kv_chunk_size_min,
+            use_checkpoint=attn.training,
+        )
+
+        hidden_states = hidden_states.to(dtype)
+
+        hidden_states = hidden_states.unflatten(0, (-1, attn.heads)).transpose(1, 2).flatten(start_dim=2)
+
+        out_proj, dropout = attn.to_out
+        hidden_states = out_proj(hidden_states)
+        hidden_states = dropout(hidden_states)
+
+        return hidden_states
+
+
 LORA_ATTENTION_PROCESSORS = (
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
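For orientation (not part of the commit), here is a minimal usage sketch. It assumes a diffusers StableDiffusionPipeline whose UNet exposes set_attn_processor; the model id, chunk sizes, and byte threshold are illustrative values rather than recommendations from this change.

# Illustrative sketch only: wiring the new processor into a pipeline.
# Model id, chunk sizes and the byte threshold are example values.
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import SubQuadraticCrossAttnProcessor

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Apply the sub-quadratic processor to every attention layer in the UNet.
pipe.unet.set_attn_processor(
    SubQuadraticCrossAttnProcessor(
        query_chunk_size=1024,        # queries per chunk
        kv_chunk_size=None,           # None -> sqrt(key_tokens)
        kv_chunk_size_min=None,
        chunk_threshold_bytes=2**30,  # use the unchunked fast-path when QK^T fits in ~1 GiB
    )
)

image = pipe("an astronaut riding a horse").images[0]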
src/diffusers/models/sub_quadratic_attention.py (+197 -0)
@@ -0,0 +1,197 @@
# original source:
#   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
#   unspecified
# credit:
#   Amin Rezaei (original author)
#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# implementation of:
#   "Self-attention Does Not Need O(n²) Memory":
#   https://arxiv.org/abs/2112.05682v2

import math
import torch
from ..utils.dynamic_slice import dynamic_slice
from functools import partial
from torch import Tensor
from torch.utils.checkpoint import checkpoint
from typing import Optional, NamedTuple, Protocol, List


class AttnChunk(NamedTuple):
    exp_values: Tensor
    exp_weights_sum: Tensor
    max_score: Tensor


class SummarizeChunk(Protocol):
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> AttnChunk:
        ...


class ComputeQueryChunkAttn(Protocol):
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> Tensor:
        ...


def _summarize_chunk(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
) -> AttnChunk:
    attn_weights = torch.baddbmm(
        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
        query,
        key_t,
        alpha=scale,
        beta=0,
    )
    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
    max_score = max_score.detach()
    exp_weights = torch.exp(attn_weights - max_score)
    exp_values = torch.bmm(exp_weights, value)
    max_score = max_score.squeeze(-1)
    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)


def _query_chunk_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
) -> Tensor:
    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int) -> AttnChunk:
        key_chunk = dynamic_slice(key_t, (0, 0, chunk_idx), (batch_x_heads, k_channels_per_head, kv_chunk_size))
        value_chunk = dynamic_slice(value, (0, chunk_idx, 0), (batch_x_heads, kv_chunk_size, v_channels_per_head))
        return summarize_chunk(query, key_chunk, value_chunk)

    chunks: List[AttnChunk] = [chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)]
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk

    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights


# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
) -> Tensor:
    attn_scores = torch.baddbmm(
        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
        query,
        key_t,
        alpha=scale,
        beta=0,
    )
    attn_probs = attn_scores.softmax(dim=-1)
    del attn_scores
    hidden_states_slice = torch.bmm(attn_probs, value)
    return hidden_states_slice


class ScannedChunk(NamedTuple):
    chunk_idx: int
    attn_chunk: AttnChunk


def efficient_dot_product_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
):
    """Computes efficient dot-product attention given query, transposed key, and value.
    This is the efficient attention presented in
    https://arxiv.org/abs/2112.05682v2, which comes with O(sqrt(n)) memory requirements.
    Args:
        query: queries for calculating attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        key_t: keys for calculating attention with shape of
            `[batch * num_heads, channels_per_head, tokens]`.
        value: values to be used in attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        query_chunk_size: int: query chunk size
        kv_chunk_size: Optional[int]: key/value chunk size. if None: defaults to sqrt(key_tokens)
        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
    Returns:
        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
    """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
    _, _, k_tokens = key_t.shape
    scale = q_channels_per_head**-0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        return dynamic_slice(
            query, (0, chunk_idx, 0), (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
        )

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
    compute_query_chunk_attn: ComputeQueryChunkAttn = (
        partial(_get_attention_scores_no_kv_chunking, scale=scale)
        if k_tokens <= kv_chunk_size
        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
        else (
            partial(
                _query_chunk_attention,
                kv_chunk_size=kv_chunk_size,
                summarize_chunk=summarize_chunk,
            )
        )
    )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key_t=key_t,
            value=value,
        )

    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    res = torch.cat(
        [
            compute_query_chunk_attn(
                query=get_query_chunk(i * query_chunk_size),
                key_t=key_t,
                value=value,
            )
            for i in range(math.ceil(q_tokens / query_chunk_size))
        ],
        dim=1,
    )
    return res
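As a rough correctness sketch (again, not part of the commit), the chunked routine can be compared against plain softmax attention on random tensors. The import path assumes the new file lands at src/diffusers/models/sub_quadratic_attention.py; the shapes and chunk sizes are arbitrary, chosen so the chunk sizes divide the token counts evenly.

# Illustrative sketch only: check the chunked result against ordinary softmax attention.
# Import path assumed from the file location in this commit; shapes are arbitrary.
import torch
from diffusers.models.sub_quadratic_attention import efficient_dot_product_attention

batch_x_heads, q_tokens, k_tokens, channels = 8, 2048, 2048, 64
query = torch.randn(batch_x_heads, q_tokens, channels)
key_t = torch.randn(batch_x_heads, channels, k_tokens)  # key is passed pre-transposed
value = torch.randn(batch_x_heads, k_tokens, channels)

out_chunked = efficient_dot_product_attention(
    query, key_t, value,
    query_chunk_size=512,   # 4 query chunks
    kv_chunk_size=256,      # 8 key/value chunks per query chunk
    use_checkpoint=False,   # inference: skip activation checkpointing
)

# Reference: unchunked scaled-dot-product attention.
scale = channels ** -0.5
reference = torch.softmax(query @ key_t * scale, dim=-1) @ value

print((out_chunked - reference).abs().max())  # expected: tiny (float32 rounding only)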
