Commit 4ea48fb

[V1][Minor] Move cascade attn logic outside _prepare_inputs (vllm-project#12943)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent e31498b commit 4ea48fb

1 file changed (+89 -61 lines)


vllm/v1/worker/gpu_model_runner.py

+89 -61
@@ -476,67 +476,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             self.device, non_blocking=True).long()

         # Prepare for cascade attention if needed.
-        common_prefix_len = (scheduler_output.num_common_prefix_blocks *
-                             self.block_size)
-        if common_prefix_len == 0:
-            # Common case.
-            use_cascade = False
-        else:
-            # NOTE(woosuk): Cascade attention uses two attention kernels: one
-            # for the common prefix and the other for the rest. For the first
-            # kernel, we concatenate all the query tokens (possibly from
-            # different requests) and treat them as if they are from the same
-            # request. Then, we use bi-directional attention to process the
-            # common prefix in the KV cache. Importantly, this means that the
-            # first kernel does not do any masking.
-
-            # Consider the following example:
-            # Request 1's input query: [D, E, X]
-            # Request 1's kv cache: [A, B, C, D, E, X]
-            # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
-            # Request 2's input query: [E, Y]
-            # Request 2's kv cache: [A, B, C, D, E, Y]
-            # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
-
-            # If we use [A, B, C, D, E] as the common prefix, then the
-            # first kernel will compute the bi-directional attention between
-            # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
-            # However, this is wrong because D in Request 1 should not attend to
-            # E in the common prefix (i.e., we need masking).
-            # To avoid this, [A, B, C, D] should be the common prefix.
-            # That is, the common prefix should be capped by the minimum
-            # num_computed_tokens among the requests, and plus one to include
-            # the first token of the query.
-
-            # In practice, we use [A, B, C] as the common prefix, instead of
-            # [A, B, C, D] (i.e., the common prefix is capped by the minimum
-            # num_computed_tokens, without plus one).
-            # This is because of an implementation detail: We want to always
-            # use two kernels for cascade attention. Let's imagine:
-            # Request 3's input query: [D]
-            # Request 3's kv cache: [A, B, C, D]
-            # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
-            # If we use [A, B, C, D] as the common prefix for Request 1-3,
-            # then Request 3 will be processed only by the first kernel,
-            # and the second kernel will get an empty input. While this is not
-            # a fundamental problem, our current implementation does not support
-            # this case.
-            common_prefix_len = min(
-                common_prefix_len,
-                self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
-            # common_prefix_len should be a multiple of the block size.
-            common_prefix_len = (common_prefix_len // self.block_size *
-                                 self.block_size)
-            use_cascade = FlashAttentionBackend.use_cascade_attention(
-                common_prefix_len=common_prefix_len,
-                query_lens=num_scheduled_tokens,
-                num_query_heads=self.num_query_heads,
-                num_kv_heads=self.num_kv_heads,
-                use_alibi=False,  # FIXME
-                use_sliding_window=self.sliding_window is not None,
-                num_sms=self.num_sms,
-            )
-
+        common_prefix_len = self._compute_cascade_attn_prefix_len(
+            num_scheduled_tokens,
+            scheduler_output.num_common_prefix_blocks,
+        )
+        use_cascade = common_prefix_len > 0
         if use_cascade:
             # TODO: Optimize.
             cu_prefix_query_lens = torch.tensor(
@@ -581,6 +525,90 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices

+    def _compute_cascade_attn_prefix_len(
+        self,
+        num_scheduled_tokens: np.ndarray,
+        num_common_prefix_blocks: int,
+    ) -> int:
+        """Compute the length of the common prefix for cascade attention.
+
+        NOTE(woosuk): The common prefix length returned by this function
+        represents the length used specifically for cascade attention, not the
+        actual number of tokens shared between requests. When cascade attention
+        is disabled (use_cascade=False), this function returns 0 even if
+        requests share common tokens. Additionally, the common prefix length is
+        truncated to a multiple of the block size and may be further truncated
+        due to implementation details explained below.
+
+        Args:
+            num_scheduled_tokens: Number of tokens scheduled per request.
+            num_common_prefix_blocks: Number of shared KV cache blocks.
+
+        Returns:
+            int: Length of common prefix in tokens.
+        """
+        common_prefix_len = num_common_prefix_blocks * self.block_size
+        if common_prefix_len == 0:
+            # Common case.
+            return 0
+
+        # NOTE(woosuk): Cascade attention uses two attention kernels: one
+        # for the common prefix and the other for the rest. For the first
+        # kernel, we concatenate all the query tokens (possibly from
+        # different requests) and treat them as if they are from the same
+        # request. Then, we use bi-directional attention to process the
+        # common prefix in the KV cache. Importantly, this means that the
+        # first kernel does not do any masking.
+
+        # Consider the following example:
+        # Request 1's input query: [D, E, X]
+        # Request 1's kv cache: [A, B, C, D, E, X]
+        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
+        # Request 2's input query: [E, Y]
+        # Request 2's kv cache: [A, B, C, D, E, Y]
+        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
+
+        # If we use [A, B, C, D, E] as the common prefix, then the
+        # first kernel will compute the bi-directional attention between
+        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
+        # However, this is wrong because D in Request 1 should not attend to
+        # E in the common prefix (i.e., we need masking).
+        # To avoid this, [A, B, C, D] should be the common prefix.
+        # That is, the common prefix should be capped by the minimum
+        # num_computed_tokens among the requests, and plus one to include
+        # the first token of the query.
+
+        # In practice, we use [A, B, C] as the common prefix, instead of
+        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
+        # num_computed_tokens, without plus one).
+        # This is because of an implementation detail: We want to always
+        # use two kernels for cascade attention. Let's imagine:
+        # Request 3's input query: [D]
+        # Request 3's kv cache: [A, B, C, D]
+        # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
+        # If we use [A, B, C, D] as the common prefix for Request 1-3,
+        # then Request 3 will be processed only by the first kernel,
+        # and the second kernel will get an empty input. While this is not
+        # a fundamental problem, our current implementation does not support
+        # this case.
+        num_reqs = len(num_scheduled_tokens)
+        common_prefix_len = min(
+            common_prefix_len,
+            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
+        # common_prefix_len should be a multiple of the block size.
+        common_prefix_len = (common_prefix_len // self.block_size *
+                             self.block_size)
+        use_cascade = FlashAttentionBackend.use_cascade_attention(
+            common_prefix_len=common_prefix_len,
+            query_lens=num_scheduled_tokens,
+            num_query_heads=self.num_query_heads,
+            num_kv_heads=self.num_kv_heads,
+            use_alibi=False,  # FIXME
+            use_sliding_window=self.sliding_window is not None,
+            num_sms=self.num_sms,
+        )
+        return common_prefix_len if use_cascade else 0
+
     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
         mrope_pos_ptr = 0
         num_reqs = self.input_batch.num_reqs
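
To make the capping rule in the new _compute_cascade_attn_prefix_len helper concrete, here is a minimal standalone sketch of the same arithmetic. It is not vLLM code: the function name capped_common_prefix_len, the toy block size of 2, and the example call are illustrative assumptions, and the FlashAttentionBackend.use_cascade_attention gate (which can still force a return value of 0) is deliberately omitted.

import numpy as np

def capped_common_prefix_len(num_common_prefix_blocks: int,
                             num_computed_tokens: np.ndarray,
                             block_size: int) -> int:
    # Hypothetical sketch mirroring the capping arithmetic in the diff above;
    # the real method additionally asks FlashAttentionBackend.use_cascade_attention
    # whether cascading is worthwhile and returns 0 if not.
    common_prefix_len = num_common_prefix_blocks * block_size
    if common_prefix_len == 0:
        # No shared blocks: cascade attention is not applicable.
        return 0
    # Cap at the smallest num_computed_tokens so the mask-free first kernel
    # never covers a position that some request's query must not attend to.
    common_prefix_len = min(common_prefix_len, int(num_computed_tokens.min()))
    # Round down to a multiple of the block size.
    return common_prefix_len // block_size * block_size

# Requests 1 and 2 from the comment have num_computed_tokens of 3 and 4.
# With an illustrative block_size of 2 and 3 shared blocks (6 tokens),
# the cap gives 3 and rounding down to the block size gives 2.
print(capped_common_prefix_len(3, np.array([3, 4]), block_size=2))  # -> 2

Note that capping at the minimum (rather than minimum plus one) is what guarantees that every request keeps at least one query token for the second, masked kernel, per the comment about always using two kernels.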
