Skip to content

Commit a3152d8

Browse files
committed
Revert "move kv_chunk_size_min concern to callsite (1c4f107)" because an equivalent fast path for the 1-query-chunk, 1-kv-chunk case is already supported inside `efficient_dot_product_attention`
1 parent 59002c3 commit a3152d8

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

src/diffusers/models/cross_attention.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import torch
1818
import torch.nn.functional as F
1919
from torch import nn, Tensor
20-
import math
2120

2221
from ..utils.import_utils import is_xformers_available
2322

@@ -319,19 +318,14 @@ def __call__(
319318
_, k_tokens, _ = key.shape
320319
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
321320

322-
kv_chunk_size = min(self.kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
323-
if self.kv_chunk_size_min is not None:
324-
kv_chunk_size = max(kv_chunk_size, self.kv_chunk_size_min)
325-
326-
uses_chunking = q_tokens > self.query_chunk_size or k_tokens > kv_chunk_size
327-
328-
if uses_chunking and (self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes):
321+
if self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes:
329322
hidden_states = efficient_dot_product_attention(
330323
query,
331324
key,
332325
value,
333326
query_chunk_size=self.query_chunk_size,
334-
kv_chunk_size=kv_chunk_size,
327+
kv_chunk_size=self.kv_chunk_size,
328+
kv_chunk_size_min=self.kv_chunk_size_min,
335329
use_checkpoint=attn.training,
336330
)
337331
else:

src/diffusers/models/sub_quadratic_attention.py

+6
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def efficient_dot_product_attention(
125125
value: Tensor,
126126
query_chunk_size=1024,
127127
kv_chunk_size: Optional[int] = None,
128+
kv_chunk_size_min: Optional[int] = None,
128129
use_checkpoint=True,
129130
):
130131
"""Computes efficient dot-product attention given query, key, and value.
@@ -139,6 +140,7 @@ def efficient_dot_product_attention(
139140
`[batch * num_heads, tokens, channels_per_head]`.
140141
query_chunk_size: int: query chunks size
141142
kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
143+
kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
142144
use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
143145
Returns:
144146
Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
@@ -147,6 +149,10 @@ def efficient_dot_product_attention(
147149
_, k_tokens, _ = key.shape
148150
scale = q_channels_per_head ** -0.5
149151

152+
kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
153+
if kv_chunk_size_min is not None:
154+
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
155+
150156
def get_query_chunk(chunk_idx: int) -> Tensor:
151157
return dynamic_slice(
152158
query,

0 commit comments

Comments (0)