
Commit 84bf1c0

De-duplicate the "matmul < quota" fast-path: instead of keeping a separate unchunked branch, we can just ask for everything in one chunk, which re-uses an existing fast-path.
1 parent: d67f8b9

1 file changed: +18 −17 lines


src/diffusers/models/cross_attention.py

```diff
@@ -318,23 +318,24 @@ def __call__(
         _, k_tokens, _ = key.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
-        if self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes:
-            hidden_states = efficient_dot_product_attention(
-                query,
-                key,
-                value,
-                query_chunk_size=self.query_chunk_size,
-                kv_chunk_size=self.kv_chunk_size,
-                kv_chunk_size_min=self.kv_chunk_size_min,
-                use_checkpoint=attn.training,
-            )
-        else:
-            # the big matmul fits into our memory limit; compute via unchunked attention (it's faster)
-            attention_probs = attn.get_attention_scores(
-                query,
-                key,
-            )
-            hidden_states = torch.bmm(attention_probs, value)
+        query_chunk_size = self.query_chunk_size
+        kv_chunk_size = self.kv_chunk_size
+
+        if self.chunk_threshold_bytes is not None and qk_matmul_size_bytes <= self.chunk_threshold_bytes:
+            # the big matmul fits into our memory limit; do everything in 1 chunk,
+            # i.e. send it down the unchunked fast-path
+            query_chunk_size = q_tokens
+            kv_chunk_size = k_tokens
+
+        hidden_states = efficient_dot_product_attention(
+            query,
+            key,
+            value,
+            query_chunk_size=query_chunk_size,
+            kv_chunk_size=kv_chunk_size,
+            kv_chunk_size_min=self.kv_chunk_size_min,
+            use_checkpoint=attn.training,
+        )
 
         hidden_states = hidden_states.to(dtype)
 
```
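For context, here is a minimal sketch (not the actual diffusers/sub-quadratic attention code; `chunked_attention` and its signature are hypothetical) of why asking for everything in one chunk lands on the unchunked fast-path: when the chunk sizes equal the token counts, the work splits into ceil(q_tokens / query_chunk_size) = ceil(k_tokens / kv_chunk_size) = 1 chunk, so the chunked kernel performs a single ordinary attention pass, the same math the deleted `else` branch computed with `attn.get_attention_scores` + `torch.bmm`.

```python
# Minimal sketch, NOT the diffusers implementation; names are illustrative.
import torch

def chunked_attention(query, key, value, query_chunk_size, kv_chunk_size):
    _, q_tokens, _ = query.shape
    _, k_tokens, _ = key.shape
    if query_chunk_size >= q_tokens and kv_chunk_size >= k_tokens:
        # One chunk covers all tokens: plain unchunked attention, the same
        # math the removed `else` branch did via attention scores + bmm.
        scale = query.shape[-1] ** -0.5
        probs = torch.softmax(query @ key.transpose(-1, -2) * scale, dim=-1)
        return probs @ value
    raise NotImplementedError("chunked accumulation elided from this sketch")

# e.g. fp16 self-attention at 64x64 latents: q_tokens = k_tokens = 4096,
# batch_x_heads = 16, bytes_per_token = 2, so
# qk_matmul_size_bytes = 16 * 2 * 4096 * 4096 = 536,870,912 bytes (512 MiB).
# If that fits under chunk_threshold_bytes, the patch sets both chunk sizes
# to the full token counts, and the call degenerates to one unchunked pass:
q, k, v = (torch.randn(2, 77, 64) for _ in range(3))
out = chunked_attention(q, k, v, query_chunk_size=77, kv_chunk_size=77)
```

With both paths sharing one call site, the threshold now only tunes chunk sizes rather than selecting between two duplicated code paths.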
