@@ -476,67 +476,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
            self.device, non_blocking=True).long()

        # Prepare for cascade attention if needed.
-        common_prefix_len = (scheduler_output.num_common_prefix_blocks *
-                             self.block_size)
-        if common_prefix_len == 0:
-            # Common case.
-            use_cascade = False
-        else:
-            # NOTE(woosuk): Cascade attention uses two attention kernels: one
-            # for the common prefix and the other for the rest. For the first
-            # kernel, we concatenate all the query tokens (possibly from
-            # different requests) and treat them as if they are from the same
-            # request. Then, we use bi-directional attention to process the
-            # common prefix in the KV cache. Importantly, this means that the
-            # first kernel does not do any masking.
-
-            # Consider the following example:
-            # Request 1's input query: [D, E, X]
-            # Request 1's kv cache: [A, B, C, D, E, X]
-            # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
-            # Request 2's input query: [E, Y]
-            # Request 2's kv cache: [A, B, C, D, E, Y]
-            # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
-
-            # If we use [A, B, C, D, E] as the common prefix, then the
-            # first kernel will compute the bi-directional attention between
-            # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
-            # However, this is wrong because D in Request 1 should not attend to
-            # E in the common prefix (i.e., we need masking).
-            # To avoid this, [A, B, C, D] should be the common prefix.
-            # That is, the common prefix should be capped by the minimum
-            # num_computed_tokens among the requests, and plus one to include
-            # the first token of the query.
-
-            # In practice, we use [A, B, C] as the common prefix, instead of
-            # [A, B, C, D] (i.e., the common prefix is capped by the minimum
-            # num_computed_tokens, without plus one).
-            # This is because of an implementation detail: We want to always
-            # use two kernels for cascade attention. Let's imagine:
-            # Request 3's input query: [D]
-            # Request 3's kv cache: [A, B, C, D]
-            # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
-            # If we use [A, B, C, D] as the common prefix for Request 1-3,
-            # then Request 3 will be processed only by the first kernel,
-            # and the second kernel will get an empty input. While this is not
-            # a fundamental problem, our current implementation does not support
-            # this case.
-            common_prefix_len = min(
-                common_prefix_len,
-                self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
-            # common_prefix_len should be a multiple of the block size.
-            common_prefix_len = (common_prefix_len // self.block_size *
-                                 self.block_size)
-            use_cascade = FlashAttentionBackend.use_cascade_attention(
-                common_prefix_len=common_prefix_len,
-                query_lens=num_scheduled_tokens,
-                num_query_heads=self.num_query_heads,
-                num_kv_heads=self.num_kv_heads,
-                use_alibi=False,  # FIXME
-                use_sliding_window=self.sliding_window is not None,
-                num_sms=self.num_sms,
-            )
-
+        common_prefix_len = self._compute_cascade_attn_prefix_len(
+            num_scheduled_tokens,
+            scheduler_output.num_common_prefix_blocks,
+        )
+        use_cascade = common_prefix_len > 0
        if use_cascade:
            # TODO: Optimize.
            cu_prefix_query_lens = torch.tensor(
@@ -581,6 +525,90 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
        logits_indices = query_start_loc[1:] - 1
        return attn_metadata, logits_indices

+    def _compute_cascade_attn_prefix_len(
+        self,
+        num_scheduled_tokens: np.ndarray,
+        num_common_prefix_blocks: int,
+    ) -> int:
+        """Compute the length of the common prefix for cascade attention.
+
+        NOTE(woosuk): The common prefix length returned by this function
+        represents the length used specifically for cascade attention, not the
+        actual number of tokens shared between requests. When cascade attention
+        is disabled (use_cascade=False), this function returns 0 even if
+        requests share common tokens. Additionally, the common prefix length is
+        truncated to a multiple of the block size and may be further truncated
+        due to implementation details explained below.
+
+        Args:
+            num_scheduled_tokens: Number of tokens scheduled per request.
+            num_common_prefix_blocks: Number of shared KV cache blocks.
+
+        Returns:
+            int: Length of common prefix in tokens.
+        """
+        common_prefix_len = num_common_prefix_blocks * self.block_size
+        if common_prefix_len == 0:
+            # Common case.
+            return 0
+
+        # NOTE(woosuk): Cascade attention uses two attention kernels: one
+        # for the common prefix and the other for the rest. For the first
+        # kernel, we concatenate all the query tokens (possibly from
+        # different requests) and treat them as if they are from the same
+        # request. Then, we use bi-directional attention to process the
+        # common prefix in the KV cache. Importantly, this means that the
+        # first kernel does not do any masking.
+
+        # Consider the following example:
+        # Request 1's input query: [D, E, X]
+        # Request 1's kv cache: [A, B, C, D, E, X]
+        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
+        # Request 2's input query: [E, Y]
+        # Request 2's kv cache: [A, B, C, D, E, Y]
+        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
+
+        # If we use [A, B, C, D, E] as the common prefix, then the
+        # first kernel will compute the bi-directional attention between
+        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
+        # However, this is wrong because D in Request 1 should not attend to
+        # E in the common prefix (i.e., we need masking).
+        # To avoid this, [A, B, C, D] should be the common prefix.
+        # That is, the common prefix should be capped by the minimum
+        # num_computed_tokens among the requests, and plus one to include
+        # the first token of the query.
+
+        # In practice, we use [A, B, C] as the common prefix, instead of
+        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
+        # num_computed_tokens, without plus one).
+        # This is because of an implementation detail: We want to always
+        # use two kernels for cascade attention. Let's imagine:
+        # Request 3's input query: [D]
+        # Request 3's kv cache: [A, B, C, D]
+        # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
+        # If we use [A, B, C, D] as the common prefix for Request 1-3,
+        # then Request 3 will be processed only by the first kernel,
+        # and the second kernel will get an empty input. While this is not
+        # a fundamental problem, our current implementation does not support
+        # this case.
+        num_reqs = len(num_scheduled_tokens)
+        common_prefix_len = min(
+            common_prefix_len,
+            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
+        # common_prefix_len should be a multiple of the block size.
+        common_prefix_len = (common_prefix_len // self.block_size *
+                             self.block_size)
+        use_cascade = FlashAttentionBackend.use_cascade_attention(
+            common_prefix_len=common_prefix_len,
+            query_lens=num_scheduled_tokens,
+            num_query_heads=self.num_query_heads,
+            num_kv_heads=self.num_kv_heads,
+            use_alibi=False,  # FIXME
+            use_sliding_window=self.sliding_window is not None,
+            num_sms=self.num_sms,
+        )
+        return common_prefix_len if use_cascade else 0
+
    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        mrope_pos_ptr = 0
        num_reqs = self.input_batch.num_reqs
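
The prefix-capping arithmetic in the new `_compute_cascade_attn_prefix_len` helper can be checked in isolation. Below is a minimal, hypothetical sketch (not part of this diff) that reproduces only the min-and-round step and leaves out the `FlashAttentionBackend.use_cascade_attention` heuristic; the function name `cascade_prefix_cap` and the toy block size are made up for illustration.

```python
import numpy as np


def cascade_prefix_cap(num_common_prefix_blocks: int, block_size: int,
                       num_computed_tokens: np.ndarray) -> int:
    # Start from the number of fully shared KV cache blocks.
    common_prefix_len = num_common_prefix_blocks * block_size
    if common_prefix_len == 0:
        return 0
    # Cap by the smallest num_computed_tokens so the unmasked first kernel
    # never lets a query token attend past its own position in the prefix.
    common_prefix_len = min(common_prefix_len,
                            int(num_computed_tokens.min()))
    # Round down to a multiple of the block size.
    return common_prefix_len // block_size * block_size


# Toy version of the Request 1 / Request 2 example from the comments,
# with block_size=1 so tokens map 1:1 onto blocks:
# num_computed_tokens = [3, 4] -> the cap is 3, i.e. [A, B, C].
print(cascade_prefix_cap(5, 1, np.array([3, 4])))  # 3
```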