Commit b18a65f

Change num_seqs type from int to torch.Tensor (#8736)
1 parent 1ab8216 · commit b18a65f

File tree: 2 files changed, +12 -9 lines

  test/test_pallas.py
  torch_xla/experimental/custom_kernel.py

test/test_pallas.py (6 additions, 4 deletions)
@@ -682,6 +682,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
     kv_lens_xla = kv_lens.to("xla")
     page_indices_xla = page_indices.to("xla")
     cu_q_lens_xla = cu_q_lens.to("xla")
+    num_seqs_xla = torch.tensor(num_seqs).to('xla')

     output = ragged_paged_attention(
         q_xla,
@@ -690,7 +691,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
         kv_lens_xla,
         page_indices_xla,
         cu_q_lens_xla,
-        num_seqs=num_seqs,
+        num_seqs=num_seqs_xla,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=True)
@@ -702,7 +703,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
         kv_lens_xla,
         page_indices_xla,
         cu_q_lens_xla,
-        num_seqs=num_seqs,
+        num_seqs=num_seqs_xla,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=False)
@@ -764,6 +765,7 @@ def _verify_ragged_paged_attention_with_dynamo(
     kv_lens_xla = kv_lens.to("xla")
     page_indices_xla = page_indices.to("xla")
     cu_q_lens_xla = cu_q_lens.to("xla")
+    num_seqs_xla = torch.tensor(num_seqs).to("xla")

     def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
                                        page_indices, cu_q_lens, num_seqs,
@@ -792,7 +794,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
         kv_lens_xla,
         page_indices_xla,
         cu_q_lens_xla,
-        num_seqs=num_seqs,
+        num_seqs=num_seqs_xla,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=True,
@@ -805,7 +807,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
         kv_lens_xla,
         page_indices_xla,
         cu_q_lens_xla,
-        num_seqs=num_seqs,
+        num_seqs=num_seqs_xla,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=False,
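
For reference, a minimal sketch of the new calling convention after this change. The import path, tensor shapes, and block sizes below are illustrative assumptions inferred from the shape comments in this diff, not values taken from the test, and an actual run requires a TPU-backed XLA device:

    import torch
    from torch_xla.experimental.custom_kernel import ragged_paged_attention

    # Illustrative shapes only (assumed, not from the test); real inputs
    # come from a paged KV cache with self-consistent ragged metadata.
    num_tokens, num_heads, head_dim = 16, 4, 128
    num_kv_heads, num_pages, page_size, pages_per_seq = 4, 32, 16, 8

    q_xla = torch.randn(num_tokens, num_heads, head_dim).to("xla")
    k_pages_xla = torch.randn(num_kv_heads, num_pages, page_size, head_dim).to("xla")
    v_pages_xla = torch.randn_like(k_pages_xla)
    kv_lens_xla = torch.full((num_tokens,), 64, dtype=torch.int32).to("xla")
    page_indices_xla = torch.zeros(num_tokens, pages_per_seq, dtype=torch.int32).to("xla")
    cu_q_lens_xla = torch.arange(num_tokens + 1, dtype=torch.int32).to("xla")

    # The change in this commit: num_seqs is now a scalar tensor on the
    # XLA device, no longer a plain Python int.
    num_seqs_xla = torch.tensor(2).to("xla")

    output = ragged_paged_attention(
        q_xla, k_pages_xla, v_pages_xla, kv_lens_xla, page_indices_xla,
        cu_q_lens_xla,
        num_seqs=num_seqs_xla,
        num_kv_pages_per_block=4,
        num_queries_per_block=8,
        use_kernel=True)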

torch_xla/experimental/custom_kernel.py (6 additions, 5 deletions)
@@ -780,13 +780,14 @@ def ragged_paged_attention(
     kv_lens,  # i32[num_tokens]
     page_indices,  # i32[num_tokens, pages_per_sequence]
     cu_q_lens,  # i32[num_tokens + 1]
-    num_seqs,  # int
+    num_seqs,  # i32[]
     num_kv_pages_per_block,
     num_queries_per_block,
     use_kernel=True,
     # TODO(jevinjiang, xiowei): add attn_logits_soft_cap.
     # attn_logits_soft_cap: float | None = None,
 ):  # [batch_size, query_len, num_heads, head_dim]:
+  num_seqs = num_seqs.item()
   assert len(q.shape) == 3, "q should have 3 dimensions."
   if not use_kernel:
     return _ragged_paged_attention_nonkernel(
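
Note that the wrapper converts the tensor straight back with num_seqs.item(), so everything below that first line still operates on a plain Python int; on an XLA tensor this read forces a device-to-host transfer. A trivial illustration of the conversion:

    import torch

    num_seqs = torch.tensor(3)  # scalar tensor, as in the new API
    n = num_seqs.item()         # back to a host-side Python scalar
    assert isinstance(n, int) and n == 3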
@@ -1541,15 +1542,15 @@ def multi_queries_paged_attention_non_xla(q: torch.Tensor,


 XLA_LIB.define(
-    "ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor kv_lens, Tensor page_indices, Tensor cu_q_lens, int num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel) -> Tensor",
+    "ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor kv_lens, Tensor page_indices, Tensor cu_q_lens, Tensor num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel) -> Tensor",
 )


 @impl(XLA_LIB, "ragged_paged_attention", "XLA")
 def ragged_paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
                                v_pages: torch.Tensor, kv_lens: torch.Tensor,
                                page_indices: torch.Tensor,
-                               cu_q_lens: torch.Tensor, num_seqs: int,
+                               cu_q_lens: torch.Tensor, num_seqs: torch.Tensor,
                                num_kv_pages_per_block: int,
                                num_queries_per_block: int, use_kernel: bool):
   return ragged_paged_attention(q, k_pages, v_pages, kv_lens, page_indices,
@@ -1561,8 +1562,8 @@ def ragged_paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
 def ragged_paged_attention_non_xla(
     q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
     kv_lens: torch.Tensor, page_indices: torch.Tensor, cu_q_lens: torch.Tensor,
-    num_seqs: int, num_kv_pages_per_block: int, num_queries_per_block: int,
-    use_kernel: bool):
+    num_seqs: torch.Tensor, num_kv_pages_per_block: int,
+    num_queries_per_block: int, use_kernel: bool):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
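
The registration pattern above (declare the scalar as Tensor in the op schema, recover the Python value with .item() inside the implementation) can be reproduced in isolation. A minimal standalone sketch; the mylib library and repeat_rows op are hypothetical names invented for this illustration:

    import torch
    from torch.library import Library, impl

    MY_LIB = Library("mylib", "DEF")
    # Schema declares `Tensor n`, analogous to `Tensor num_seqs` above.
    MY_LIB.define("repeat_rows(Tensor x, Tensor n) -> Tensor")

    @impl(MY_LIB, "repeat_rows", "CompositeExplicitAutograd")
    def repeat_rows(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
      # Recover the host-side int from the scalar tensor, as
      # ragged_paged_attention now does with `num_seqs.item()`.
      return x.repeat(n.item(), 1)

    out = torch.ops.mylib.repeat_rows(torch.ones(1, 3), torch.tensor(2))
    print(out.shape)  # torch.Size([2, 3])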