@@ -780,13 +780,14 @@ def ragged_paged_attention(
780
780
kv_lens , # i32[num_tokens]
781
781
page_indices , # i32[num_tokens, pages_per_sequence]
782
782
cu_q_lens , # i32[num_tokens + 1]
783
- num_seqs , # int
783
+ num_seqs , # i32[]
784
784
num_kv_pages_per_block ,
785
785
num_queries_per_block ,
786
786
use_kernel = True ,
787
787
# TODO(jevinjiang, xiowei): add attn_logits_soft_cap.
788
788
# attn_logits_soft_cap: float | None = None,
789
789
): # [batch_size, query_len, num_heads, head_dim]:
790
+ num_seqs = num_seqs .item ()
790
791
assert len (q .shape ) == 3 , "q should have 3 dimensions."
791
792
if not use_kernel :
792
793
return _ragged_paged_attention_nonkernel (
@@ -1541,15 +1542,15 @@ def multi_queries_paged_attention_non_xla(q: torch.Tensor,
1541
1542
1542
1543
1543
1544
# Schema for the ragged paged attention op. `num_seqs` is declared as a
# Tensor; the two block sizes are ints and `use_kernel` is a bool.
XLA_LIB.define(
    "ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, "
    "Tensor kv_lens, Tensor page_indices, Tensor cu_q_lens, "
    "Tensor num_seqs, int num_kv_pages_per_block, "
    "int num_queries_per_block, bool use_kernel) -> Tensor")
1546
1547
1547
1548
1548
1549
@impl (XLA_LIB , "ragged_paged_attention" , "XLA" )
1549
1550
def ragged_paged_attention_xla (q : torch .Tensor , k_pages : torch .Tensor ,
1550
1551
v_pages : torch .Tensor , kv_lens : torch .Tensor ,
1551
1552
page_indices : torch .Tensor ,
1552
- cu_q_lens : torch .Tensor , num_seqs : int ,
1553
+ cu_q_lens : torch .Tensor , num_seqs : torch . Tensor ,
1553
1554
num_kv_pages_per_block : int ,
1554
1555
num_queries_per_block : int , use_kernel : bool ):
1555
1556
return ragged_paged_attention (q , k_pages , v_pages , kv_lens , page_indices ,
@@ -1561,8 +1562,8 @@ def ragged_paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
1561
1562
def ragged_paged_attention_non_xla(
    q: torch.Tensor,
    k_pages: torch.Tensor,
    v_pages: torch.Tensor,
    kv_lens: torch.Tensor,
    page_indices: torch.Tensor,
    cu_q_lens: torch.Tensor,
    num_seqs: torch.Tensor,
    num_kv_pages_per_block: int,
    num_queries_per_block: int,
    use_kernel: bool,
):
  """Non-XLA variant of ragged paged attention.

  Only `q`, `k_pages` and `v_pages` are forwarded, together with the mode
  string "paged", to `non_xla_attetion`. The remaining parameters mirror
  the `ragged_paged_attention` schema string (including `num_seqs` as a
  tensor) but are not read by this function.

  Returns:
    Whatever `non_xla_attetion` returns for the given inputs.
  """
  paged_result = non_xla_attetion(q, k_pages, v_pages, "paged")
  return paged_result
1567
1568
1568
1569
0 commit comments