diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 98173cb8a406..1e8453399dff 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Optional, Union
+from .sub_quadratic_attention import efficient_dot_product_attention
 
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import nn, Tensor
 
 from ..utils.import_utils import is_xformers_available
 
@@ -145,6 +146,29 @@ def set_attention_slice(self, slice_size):
             processor = CrossAttnProcessor()
 
         self.set_processor(processor)
+
+    def set_subquadratic_attention(
+        self,
+        query_chunk_size: int = 1024,
+        kv_chunk_size: Optional[int] = None,
+        kv_chunk_size_min: Optional[int] = None,
+        chunk_threshold_bytes: Optional[int] = None,
+    ):
+        r"""
+        Args:
+            query_chunk_size (`int`, *optional*, defaults to `1024`): number of query tokens processed per chunk.
+            kv_chunk_size (`int`, *optional*, defaults to `None`): if `None`, `sqrt(key_tokens)` is used.
+            kv_chunk_size_min (`int`, *optional*, defaults to `None`): only considered when `kv_chunk_size` is `None`. Changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
+            chunk_threshold_bytes (`int`, *optional*, defaults to `None`): if defined, only chunk when the self-attention matmul would allocate more bytes than this. Whenever traditional attention fits into memory, prefer it: the unchunked algorithm is faster.
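+
+        Examples:
+
+            A minimal sketch of enabling sub-quadratic attention on every `CrossAttention` module of a UNet
+            (the `pipe` pipeline object below is illustrative, not part of this API):
+
+            ```py
+            >>> from diffusers.models.cross_attention import CrossAttention
+
+            >>> for module in pipe.unet.modules():
+            ...     if isinstance(module, CrossAttention):
+            ...         module.set_subquadratic_attention(
+            ...             query_chunk_size=1024,
+            ...             kv_chunk_size=None,  # None -> sqrt(key_tokens)
+            ...             chunk_threshold_bytes=2**30,  # skip chunking while the matmul stays under ~1 GiB
+            ...         )
+            ```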
+ """ + self.query_chunk_size = query_chunk_size + self.kv_chunk_size = kv_chunk_size + self.kv_chunk_size_min = kv_chunk_size_min + self.chunk_threshold_bytes = chunk_threshold_bytes + + def __call__( + self, + attn: CrossAttention, + hidden_states: Tensor, + encoder_hidden_states: Optional[Tensor]=None, + attention_mask: Optional[Tensor]=None, + ): + encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + + assert attention_mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." + # I don't know what test case can be used to determine whether softmax is computed at sufficient bit-width, + # but sub-quadratic attention has a pretty bespoke softmax (defers computation of the denominator) so this needs some thought. + assert not attn.upcast_softmax or torch.finfo(hidden_states.dtype).bits >= 32, "upcast_softmax was requested, but is not implemented" + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1) + key_t = key.transpose(1,2).unflatten(1, (attn.heads, -1)).flatten(end_dim=1) + del key + value = value.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1) + + dtype = query.dtype + # TODO: do we still need to do *everything* in float32, given how we delay the division? + # TODO: do we need to support upcast_softmax too? SD 2.1 seems to work without it + if attn.upcast_attention: + query = query.float() + key_t = key_t.float() + + bytes_per_token = torch.finfo(query.dtype).bits//8 + batch_x_heads, q_tokens, _ = query.shape + _, _, k_tokens = key_t.shape + qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens + + query_chunk_size = self.query_chunk_size + kv_chunk_size = self.kv_chunk_size + + if self.chunk_threshold_bytes is not None and qk_matmul_size_bytes <= self.chunk_threshold_bytes: + # the big matmul fits into our memory limit; do everything in 1 chunk, + # i.e. 
+
+        query_chunk_size = self.query_chunk_size
+        kv_chunk_size = self.kv_chunk_size
+
+        if self.chunk_threshold_bytes is not None and qk_matmul_size_bytes <= self.chunk_threshold_bytes:
+            # the big matmul fits into our memory limit; do everything in 1 chunk,
+            # i.e. send it down the unchunked fast-path
+            query_chunk_size = q_tokens
+            kv_chunk_size = k_tokens
+
+        hidden_states = efficient_dot_product_attention(
+            query,
+            key_t,
+            value,
+            query_chunk_size=query_chunk_size,
+            kv_chunk_size=kv_chunk_size,
+            kv_chunk_size_min=self.kv_chunk_size_min,
+            use_checkpoint=attn.training,
+        )
+
+        hidden_states = hidden_states.to(dtype)
+
+        # unfold heads from the batch dim: [batch * heads, q_tokens, channels] -> [batch, q_tokens, heads * channels]
+        hidden_states = hidden_states.unflatten(0, (-1, attn.heads)).transpose(1, 2).flatten(start_dim=2)
+
+        out_proj, dropout = attn.to_out
+        hidden_states = out_proj(hidden_states)
+        hidden_states = dropout(hidden_states)
+
+        return hidden_states
+
 
 class CrossAttnAddedKVProcessor:
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
diff --git a/src/diffusers/models/sub_quadratic_attention.py b/src/diffusers/models/sub_quadratic_attention.py
new file mode 100644
index 000000000000..a2e8aea513f5
--- /dev/null
+++ b/src/diffusers/models/sub_quadratic_attention.py
@@ -0,0 +1,194 @@
+# original source:
+#   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
+# license:
+#   unspecified
+# credit:
+#   Amin Rezaei (original author)
+#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
+# implementation of:
+#   "Self-attention Does Not Need O(n^2) Memory":
+#   https://arxiv.org/abs/2112.05682v2
+
+from functools import partial
+import torch
+from torch import Tensor
+from torch.utils.checkpoint import checkpoint
+import math
+from typing import Optional, NamedTuple, Protocol, List
+
+from ..utils.dynamic_slice import dynamic_slice
+
+class AttnChunk(NamedTuple):
+    exp_values: Tensor
+    exp_weights_sum: Tensor
+    max_score: Tensor
+
+class SummarizeChunk(Protocol):
+    @staticmethod
+    def __call__(
+        query: Tensor,
+        key_t: Tensor,
+        value: Tensor,
+    ) -> AttnChunk: ...
+
+class ComputeQueryChunkAttn(Protocol):
+    @staticmethod
+    def __call__(
+        query: Tensor,
+        key_t: Tensor,
+        value: Tensor,
+    ) -> Tensor: ...
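+
+# explanatory note (added commentary): for each key/value chunk, _summarize_chunk below keeps three
+# per-query quantities instead of a normalized softmax:
+#   max_score       = max_j s_j                         (maximum raw score within the chunk)
+#   exp_weights_sum = sum_j exp(s_j - max_score)         (softmax denominator, up to a factor exp(max_score))
+#   exp_values      = sum_j exp(s_j - max_score) * v_j   (softmax numerator, up to the same factor)
+# _query_chunk_attention then rescales each chunk by exp(max_score_chunk - max_score_global), sums the
+# numerators and denominators across chunks, and divides; the result equals softmax(s) @ V without ever
+# materializing the full [query_tokens, key_tokens] score matrix.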
+
+def _summarize_chunk(
+    query: Tensor,
+    key_t: Tensor,
+    value: Tensor,
+    scale: float,
+) -> AttnChunk:
+    # baddbmm with beta=0 ignores its (uninitialized) input tensor: this is just a fused, scaled batched matmul,
+    # attn_weights = scale * (query @ key_t)
+    attn_weights = torch.baddbmm(
+        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        query,
+        key_t,
+        alpha=scale,
+        beta=0,
+    )
+    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
+    max_score = max_score.detach()
+    exp_weights = torch.exp(attn_weights - max_score)
+    exp_values = torch.bmm(exp_weights, value)
+    max_score = max_score.squeeze(-1)
+    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
+
+def _query_chunk_attention(
+    query: Tensor,
+    key_t: Tensor,
+    value: Tensor,
+    summarize_chunk: SummarizeChunk,
+    kv_chunk_size: int,
+) -> Tensor:
+    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
+    _, _, v_channels_per_head = value.shape
+
+    def chunk_scanner(chunk_idx: int) -> AttnChunk:
+        key_chunk = dynamic_slice(
+            key_t,
+            (0, 0, chunk_idx),
+            (batch_x_heads, k_channels_per_head, kv_chunk_size)
+        )
+        value_chunk = dynamic_slice(
+            value,
+            (0, chunk_idx, 0),
+            (batch_x_heads, kv_chunk_size, v_channels_per_head)
+        )
+        return summarize_chunk(query, key_chunk, value_chunk)
+
+    chunks: List[AttnChunk] = [
+        chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
+    ]
+    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
+    chunk_values, chunk_weights, chunk_max = acc_chunk
+
+    # renormalize every chunk against the global (across-chunk) maximum before summing
+    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
+    max_diffs = torch.exp(chunk_max - global_max)
+    chunk_values *= torch.unsqueeze(max_diffs, -1)
+    chunk_weights *= max_diffs
+
+    all_values = chunk_values.sum(dim=0)
+    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
+    return all_values / all_weights
+
+# TODO: refactor CrossAttention#get_attention_scores to share code with this
+def _get_attention_scores_no_kv_chunking(
+    query: Tensor,
+    key_t: Tensor,
+    value: Tensor,
+    scale: float,
+) -> Tensor:
+    attn_scores = torch.baddbmm(
+        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        query,
+        key_t,
+        alpha=scale,
+        beta=0,
+    )
+    attn_probs = attn_scores.softmax(dim=-1)
+    del attn_scores
+    hidden_states_slice = torch.bmm(attn_probs, value)
+    return hidden_states_slice
+
+class ScannedChunk(NamedTuple):
+    chunk_idx: int
+    attn_chunk: AttnChunk
+
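+# a minimal, self-contained usage sketch of the function below (illustrative shapes and values,
+# not taken from a real model):
+#
+#   import torch
+#   from diffusers.models.sub_quadratic_attention import efficient_dot_product_attention
+#
+#   batch_x_heads, q_tokens, k_tokens, channels_per_head = 16, 4096, 4096, 40
+#   query = torch.randn(batch_x_heads, q_tokens, channels_per_head)
+#   key_t = torch.randn(batch_x_heads, channels_per_head, k_tokens)  # note: key is passed pre-transposed
+#   value = torch.randn(batch_x_heads, k_tokens, channels_per_head)
+#   out = efficient_dot_product_attention(query, key_t, value, query_chunk_size=1024, use_checkpoint=False)
+#   assert out.shape == (batch_x_heads, q_tokens, channels_per_head)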
+def efficient_dot_product_attention(
+    query: Tensor,
+    key_t: Tensor,
+    value: Tensor,
+    query_chunk_size: int = 1024,
+    kv_chunk_size: Optional[int] = None,
+    kv_chunk_size_min: Optional[int] = None,
+    use_checkpoint: bool = True,
+) -> Tensor:
+    """Computes efficient dot-product attention given query, transposed key, and value.
+    This is the memory-efficient attention presented in
+    https://arxiv.org/abs/2112.05682v2, which comes with O(sqrt(n)) memory requirements.
+    Args:
+        query: queries for calculating attention with shape of
+            `[batch * num_heads, tokens, channels_per_head]`.
+        key_t: keys for calculating attention with shape of
+            `[batch * num_heads, channels_per_head, tokens]`.
+        value: values to be used in attention with shape of
+            `[batch * num_heads, tokens, channels_per_head]`.
+        query_chunk_size: int: query chunk size
+        kv_chunk_size: Optional[int]: key/value chunk size. If None, defaults to sqrt(key_tokens)
+        kv_chunk_size_min: Optional[int]: minimum key/value chunk size. Only considered when kv_chunk_size is None. Changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
+        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
+    Returns:
+        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
+    """
+    batch_x_heads, q_tokens, q_channels_per_head = query.shape
+    _, _, k_tokens = key_t.shape
+    scale = q_channels_per_head ** -0.5
+
+    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
+    if kv_chunk_size_min is not None:
+        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
+
+    def get_query_chunk(chunk_idx: int) -> Tensor:
+        return dynamic_slice(
+            query,
+            (0, chunk_idx, 0),
+            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
+        )
+
+    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
+    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
+    # when there's just 1 key/value chunk per query chunk, take the unchunked-kv fast-path
+    # (this is just sliced attention, slicing over queries only)
+    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
+        _get_attention_scores_no_kv_chunking,
+        scale=scale
+    ) if k_tokens <= kv_chunk_size else (
+        partial(
+            _query_chunk_attention,
+            kv_chunk_size=kv_chunk_size,
+            summarize_chunk=summarize_chunk,
+        )
+    )
+
+    if q_tokens <= query_chunk_size:
+        # fast-path for when there's just 1 query chunk
+        return compute_query_chunk_attn(
+            query=query,
+            key_t=key_t,
+            value=value,
+        )
+
+    # TODO: maybe we should use torch.empty_like(query) to allocate storage in advance,
+    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
+    res = torch.cat([
+        compute_query_chunk_attn(
+            query=get_query_chunk(i * query_chunk_size),
+            key_t=key_t,
+            value=value,
+        ) for i in range(math.ceil(q_tokens / query_chunk_size))
+    ], dim=1)
+    return res
diff --git a/src/diffusers/utils/dynamic_slice.py b/src/diffusers/utils/dynamic_slice.py
new file mode 100644
index 000000000000..046678bb51f4
--- /dev/null
+++ b/src/diffusers/utils/dynamic_slice.py
@@ -0,0 +1,10 @@
+from torch import Tensor
+from typing import List
+
+def dynamic_slice(
+    x: Tensor,
+    starts: List[int],
+    sizes: List[int],
+) -> Tensor:
+    # returns x[starts[0] : starts[0] + sizes[0], starts[1] : starts[1] + sizes[1], ...];
+    # indexing with a tuple (rather than a list) of slices avoids PyTorch's deprecated sequence-indexing path
+    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
+    return x[tuple(slicing)]
\ No newline at end of file