Skip to content

Commit 2e4f073

Browse files
authored
Revert "Change num_seqs type from int to torch.Tensor" (#8767)
1 parent 6712eb9 commit 2e4f073

File tree

2 files changed

+9
-12
lines changed

2 files changed

+9
-12
lines changed

test/test_pallas.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,6 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
682682
kv_lens_xla = kv_lens.to("xla")
683683
page_indices_xla = page_indices.to("xla")
684684
cu_q_lens_xla = cu_q_lens.to("xla")
685-
num_seqs_xla = torch.tensor(num_seqs).to('xla')
686685

687686
output = ragged_paged_attention(
688687
q_xla,
@@ -691,7 +690,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
691690
kv_lens_xla,
692691
page_indices_xla,
693692
cu_q_lens_xla,
694-
num_seqs=num_seqs_xla,
693+
num_seqs=num_seqs,
695694
num_kv_pages_per_block=num_kv_pages_per_block,
696695
num_queries_per_block=num_queries_per_block,
697696
use_kernel=True)
@@ -703,7 +702,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
703702
kv_lens_xla,
704703
page_indices_xla,
705704
cu_q_lens_xla,
706-
num_seqs=num_seqs_xla,
705+
num_seqs=num_seqs,
707706
num_kv_pages_per_block=num_kv_pages_per_block,
708707
num_queries_per_block=num_queries_per_block,
709708
use_kernel=False)
@@ -765,7 +764,6 @@ def _verify_ragged_paged_attention_with_dynamo(
765764
kv_lens_xla = kv_lens.to("xla")
766765
page_indices_xla = page_indices.to("xla")
767766
cu_q_lens_xla = cu_q_lens.to("xla")
768-
num_seqs_xla = torch.tensor(num_seqs).to("xla")
769767

770768
def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
771769
page_indices, cu_q_lens, num_seqs,
@@ -794,7 +792,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
794792
kv_lens_xla,
795793
page_indices_xla,
796794
cu_q_lens_xla,
797-
num_seqs=num_seqs_xla,
795+
num_seqs=num_seqs,
798796
num_kv_pages_per_block=num_kv_pages_per_block,
799797
num_queries_per_block=num_queries_per_block,
800798
use_kernel=True,
@@ -807,7 +805,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
807805
kv_lens_xla,
808806
page_indices_xla,
809807
cu_q_lens_xla,
810-
num_seqs=num_seqs_xla,
808+
num_seqs=num_seqs,
811809
num_kv_pages_per_block=num_kv_pages_per_block,
812810
num_queries_per_block=num_queries_per_block,
813811
use_kernel=False,

torch_xla/experimental/custom_kernel.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -780,14 +780,13 @@ def ragged_paged_attention(
780780
kv_lens, # i32[num_tokens]
781781
page_indices, # i32[num_tokens, pages_per_sequence]
782782
cu_q_lens, # i32[num_tokens + 1]
783-
num_seqs, # i32[]
783+
num_seqs, # int
784784
num_kv_pages_per_block,
785785
num_queries_per_block,
786786
use_kernel=True,
787787
# TODO(jevinjiang, xiowei): add attn_logits_soft_cap.
788788
# attn_logits_soft_cap: float | None = None,
789789
): # [batch_size, query_len, num_heads, head_dim]:
790-
num_seqs = num_seqs.item()
791790
assert len(q.shape) == 3, "q should have 3 dimensions."
792791
if not use_kernel:
793792
return _ragged_paged_attention_nonkernel(
@@ -1542,15 +1541,15 @@ def multi_queries_paged_attention_non_xla(q: torch.Tensor,
15421541

15431542

15441543
XLA_LIB.define(
1545-
"ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor kv_lens, Tensor page_indices, Tensor cu_q_lens, Tensor num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel) -> Tensor",
1544+
"ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor kv_lens, Tensor page_indices, Tensor cu_q_lens, int num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel) -> Tensor",
15461545
)
15471546

15481547

15491548
@impl(XLA_LIB, "ragged_paged_attention", "XLA")
15501549
def ragged_paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
15511550
v_pages: torch.Tensor, kv_lens: torch.Tensor,
15521551
page_indices: torch.Tensor,
1553-
cu_q_lens: torch.Tensor, num_seqs: torch.Tensor,
1552+
cu_q_lens: torch.Tensor, num_seqs: int,
15541553
num_kv_pages_per_block: int,
15551554
num_queries_per_block: int, use_kernel: bool):
15561555
return ragged_paged_attention(q, k_pages, v_pages, kv_lens, page_indices,
@@ -1562,8 +1561,8 @@ def ragged_paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,
15621561
def ragged_paged_attention_non_xla(
15631562
q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
15641563
kv_lens: torch.Tensor, page_indices: torch.Tensor, cu_q_lens: torch.Tensor,
1565-
num_seqs: torch.Tensor, num_kv_pages_per_block: int,
1566-
num_queries_per_block: int, use_kernel: bool):
1564+
num_seqs: int, num_kv_pages_per_block: int, num_queries_per_block: int,
1565+
use_kernel: bool):
15671566
return non_xla_attetion(q, k_pages, v_pages, "paged")
15681567

15691568

0 commit comments

Comments (0)