@@ -780,14 +780,13 @@ def ragged_paged_attention(
780
780
kv_lens , # i32[num_tokens]
781
781
page_indices , # i32[num_tokens, pages_per_sequence]
782
782
cu_q_lens , # i32[num_tokens + 1]
783
- num_seqs , # i32[]
783
+ num_seqs , # int
784
784
num_kv_pages_per_block ,
785
785
num_queries_per_block ,
786
786
use_kernel = True ,
787
787
# TODO(jevinjiang, xiowei): add attn_logits_soft_cap.
788
788
# attn_logits_soft_cap: float | None = None,
789
789
): # [batch_size, query_len, num_heads, head_dim]:
790
- num_seqs = num_seqs .item ()
791
790
assert len (q .shape ) == 3 , "q should have 3 dimensions."
792
791
if not use_kernel :
793
792
return _ragged_paged_attention_nonkernel (
@@ -1542,15 +1541,15 @@ def multi_queries_paged_attention_non_xla(q: torch.Tensor,
1542
1541
1543
1542
1544
1543
XLA_LIB .define (
1545
- "ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor kv_lens, Tensor page_indices, Tensor cu_q_lens, Tensor num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel) -> Tensor" ,
1544
+ "ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor kv_lens, Tensor page_indices, Tensor cu_q_lens, int num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel) -> Tensor" ,
1546
1545
)
1547
1546
1548
1547
1549
1548
@impl (XLA_LIB , "ragged_paged_attention" , "XLA" )
1550
1549
def ragged_paged_attention_xla (q : torch .Tensor , k_pages : torch .Tensor ,
1551
1550
v_pages : torch .Tensor , kv_lens : torch .Tensor ,
1552
1551
page_indices : torch .Tensor ,
1553
- cu_q_lens : torch .Tensor , num_seqs : torch . Tensor ,
1552
+ cu_q_lens : torch .Tensor , num_seqs : int ,
1554
1553
num_kv_pages_per_block : int ,
1555
1554
num_queries_per_block : int , use_kernel : bool ):
1556
1555
return ragged_paged_attention (q , k_pages , v_pages , kv_lens , page_indices ,
@@ -1562,8 +1561,8 @@ def ragged_paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
1562
1561
def ragged_paged_attention_non_xla (
1563
1562
q : torch .Tensor , k_pages : torch .Tensor , v_pages : torch .Tensor ,
1564
1563
kv_lens : torch .Tensor , page_indices : torch .Tensor , cu_q_lens : torch .Tensor ,
1565
- num_seqs : torch . Tensor , num_kv_pages_per_block : int ,
1566
- num_queries_per_block : int , use_kernel : bool ):
1564
+ num_seqs : int , num_kv_pages_per_block : int , num_queries_per_block : int ,
1565
+ use_kernel : bool ):
1567
1566
return non_xla_attetion (q , k_pages , v_pages , "paged" )
1568
1567
1569
1568
0 commit comments