@@ -318,23 +318,24 @@ def __call__(
         _, k_tokens, _ = key.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
-        if self.chunk_threshold_bytes is None or qk_matmul_size_bytes > self.chunk_threshold_bytes:
-            hidden_states = efficient_dot_product_attention(
-                query,
-                key,
-                value,
-                query_chunk_size=self.query_chunk_size,
-                kv_chunk_size=self.kv_chunk_size,
-                kv_chunk_size_min=self.kv_chunk_size_min,
-                use_checkpoint=attn.training,
-            )
-        else:
-            # the big matmul fits into our memory limit; compute via unchunked attention (it's faster)
-            attention_probs = attn.get_attention_scores(
-                query,
-                key,
-            )
-            hidden_states = torch.bmm(attention_probs, value)
+        query_chunk_size = self.query_chunk_size
+        kv_chunk_size = self.kv_chunk_size
+
+        if self.chunk_threshold_bytes is not None and qk_matmul_size_bytes <= self.chunk_threshold_bytes:
+            # the big matmul fits into our memory limit; do everything in 1 chunk,
+            # i.e. send it down the unchunked fast-path
+            query_chunk_size = q_tokens
+            kv_chunk_size = k_tokens
+
+        hidden_states = efficient_dot_product_attention(
+            query,
+            key,
+            value,
+            query_chunk_size=query_chunk_size,
+            kv_chunk_size=kv_chunk_size,
+            kv_chunk_size_min=self.kv_chunk_size_min,
+            use_checkpoint=attn.training,
+        )
 
         hidden_states = hidden_states.to(dtype)
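For reference, a minimal sketch of the chunk-size selection introduced above, standalone rather than as part of the processor. The helper name, its signature, and the derivation of `bytes_per_token` from the query dtype are illustrative assumptions, not the processor's actual API:

```python
from typing import Optional, Tuple
import torch

def pick_chunk_sizes(
    query: torch.Tensor,       # (batch * heads, q_tokens, channels_per_head)
    key: torch.Tensor,         # (batch * heads, k_tokens, channels_per_head)
    query_chunk_size: int,     # configured chunk size for queries
    kv_chunk_size: int,        # configured chunk size for keys/values
    chunk_threshold_bytes: Optional[int],  # memory budget for the full QK^T matmul
) -> Tuple[int, int]:
    batch_x_heads, q_tokens, _ = query.shape
    _, k_tokens, _ = key.shape
    # assumption: element size of the attention-score matrix follows the query dtype
    bytes_per_token = torch.finfo(query.dtype).bits // 8
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the full attention-score matrix fits in the budget:
        # a single chunk covering all tokens gives the unchunked fast path
        return q_tokens, k_tokens
    # otherwise keep the configured chunk sizes and run chunked attention
    return query_chunk_size, kv_chunk_size

# usage sketch: half-precision 4096x4096 scores over 16 batch*heads fit under a 1 GiB budget
q = torch.randn(16, 4096, 64, dtype=torch.float16)
k = torch.randn(16, 4096, 64, dtype=torch.float16)
print(pick_chunk_sizes(q, k, query_chunk_size=1024, kv_chunk_size=4096, chunk_threshold_bytes=2**30))
```

Routing the fits-in-memory case through `efficient_dot_product_attention` with chunk sizes equal to the full token counts collapses the computation to a single chunk, so the fast path no longer needs the separate `attn.get_attention_scores` / `torch.bmm` branch.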