Skip to content

Commit 1c4f107

Browse files
committed
move the kv_chunk_size_min concern to the callsite, since if the caller knows the final kv_chunk_size, they can notice when no chunking would happen at all and use the fast path. Note: there is a question of whether that concern belongs *inside* the algorithm, but it would feel weird for chunked attention to have a no-chunking-at-all branch.
1 parent fbd3ac7 commit 1c4f107

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

src/diffusers/models/cross_attention.py

+9 −3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import torch
1818
import torch.nn.functional as F
1919
from torch import nn, Tensor
20+
import math
2021

2122
from ..utils.import_utils import is_xformers_available
2223

@@ -318,14 +319,19 @@ def __call__(
318319
_, k_tokens, _ = key.shape
319320
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
320321

321-
if self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes:
322+
kv_chunk_size = min(self.kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
323+
if self.kv_chunk_size_min is not None:
324+
kv_chunk_size = max(kv_chunk_size, self.kv_chunk_size_min)
325+
326+
uses_chunking = q_tokens > self.query_chunk_size or k_tokens > kv_chunk_size
327+
328+
if uses_chunking and (self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes):
322329
hidden_states = efficient_dot_product_attention(
323330
query,
324331
key,
325332
value,
326333
query_chunk_size=self.query_chunk_size,
327-
kv_chunk_size=self.kv_chunk_size,
328-
kv_chunk_size_min=self.kv_chunk_size_min,
334+
kv_chunk_size=kv_chunk_size,
329335
use_checkpoint=attn.training,
330336
)
331337
else:

src/diffusers/models/sub_quadratic_attention.py

-6
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def efficient_dot_product_attention(
125125
value: Tensor,
126126
query_chunk_size=1024,
127127
kv_chunk_size: Optional[int] = None,
128-
kv_chunk_size_min: Optional[int] = None,
129128
use_checkpoint=True,
130129
):
131130
"""Computes efficient dot-product attention given query, key, and value.
@@ -140,7 +139,6 @@ def efficient_dot_product_attention(
140139
`[batch * num_heads, tokens, channels_per_head]`.
141140
query_chunk_size: int: query chunks size
142141
kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
143-
kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
144142
use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
145143
Returns:
146144
Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
@@ -149,10 +147,6 @@ def efficient_dot_product_attention(
149147
_, k_tokens, _ = key.shape
150148
scale = q_channels_per_head ** -0.5
151149

152-
kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
153-
if kv_chunk_size_min is not None:
154-
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
155-
156150
def get_query_chunk(chunk_idx: int) -> Tensor:
157151
return dynamic_slice(
158152
query,

0 commit comments

Comments (0)