Commit 80929a9

Fix ragged_paged_attention op signature (#8943)
1 parent: ac9a39f

2 files changed: +8 −2 lines changed

test/test_pallas.py (+4 −1)
@@ -678,6 +678,7 @@ def ragged_paged_attention_wrapper(
 sliding_window=sliding_window,
 soft_cap=soft_cap,
 use_kernel=True,
+max_model_len=2048,
 num_kv_pages_per_block=num_kv_pages_per_block,
 num_queries_per_block=num_queries_per_block,
 ):
@@ -692,6 +693,7 @@ def ragged_paged_attention_wrapper(
 sliding_window=sliding_window,
 soft_cap=soft_cap,
 use_kernel=use_kernel,
+max_model_len=max_model_len,
 num_kv_pages_per_block=num_kv_pages_per_block,
 num_queries_per_block=num_queries_per_block,
 )
@@ -712,6 +714,7 @@ def ragged_paged_attention_wrapper(
 sliding_window=sliding_window,
 soft_cap=soft_cap,
 use_kernel=True,
+max_model_len=2048,
 num_kv_pages_per_block=num_kv_pages_per_block,
 num_queries_per_block=num_queries_per_block,
 )[:cu_q_lens[num_seqs]]
@@ -752,12 +755,12 @@ def ragged_paged_attention_wrapper(

 from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
 from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size
+max_model_len = 2048
 if num_kv_pages_per_block is None:
 assert num_queries_per_block is None
 token_num = q.shape[0]
 token_num, q_head_num, _ = q.shape
 kv_head_num = kv_pages[2] // 2
-max_model_len = 2048
 num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
 q_head_num, kv_head_num, token_num, max_model_len)
 jax_kernel_output = torch.from_numpy(
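
The test change hoists max_model_len = 2048 above the tuned-block-size lookup so the same value is passed both to the helper and to the op's new keyword argument. A minimal sketch of that lookup, using the call shown in the diff; the head counts and token count below are illustrative assumptions, not values taken from the test:

from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size

# Illustrative problem sizes (assumptions, not from the commit).
q_head_num, kv_head_num, token_num = 8, 1, 1024
max_model_len = 2048  # now defined before the lookup, matching the op's new default

# Returns the tuned (num_kv_pages_per_block, num_queries_per_block) pair
# for this head/token configuration and model length.
num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
    q_head_num, kv_head_num, token_num, max_model_len)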

torch_xla/experimental/custom_kernel.py (+4 −1)

@@ -1681,7 +1681,7 @@ def non_xla_ragged_paged_attention(q, kv, attention_type):
 XLA_LIB.define(
 "ragged_paged_attention(Tensor q, Tensor kv_pages, Tensor kv_lens, Tensor page_indices, "
 "Tensor cu_q_lens, Tensor num_seqs, float sm_scale=1, int? sliding_window=None, "
-"float? soft_cap=None, float? mask_value=None, bool use_kernel=True, "
+"float? soft_cap=None, float? mask_value=None, bool use_kernel=True, int max_model_len=2048,"
 "int? num_kv_pages_per_block=None, int? num_queries_per_block=None, int? vmem_limit_bytes=None) -> Tensor",
 )

@@ -1699,6 +1699,7 @@ def ragged_paged_attention_xla(
 soft_cap: float | None = None,
 mask_value=None,
 use_kernel=True,
+max_model_len=2048,
 # kernel tuning parameters
 num_kv_pages_per_block=None,
 num_queries_per_block=None,
@@ -1716,6 +1717,7 @@ def ragged_paged_attention_xla(
 soft_cap=soft_cap,
 mask_value=mask_value,
 use_kernel=use_kernel,
+max_model_len=max_model_len,
 num_kv_pages_per_block=num_kv_pages_per_block,
 num_queries_per_block=num_queries_per_block,
 vmem_limit_bytes=vmem_limit_bytes)
@@ -1734,6 +1736,7 @@ def ragged_paged_attention_non_xla(
 soft_cap: float | None = None,
 mask_value=None,
 use_kernel=True,
+max_model_len=2048,
 # kernel tuning parameters
 num_kv_pages_per_block=None,
 num_queries_per_block=None,
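
The schema change adds int max_model_len=2048 to the registered op, and both the XLA and non-XLA implementations now accept and forward it. A minimal sketch of calling the registered op with the new argument; the wrapper name and keyword-only style are illustrative assumptions, and the input tensors (q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs) are assumed to already be laid out as the paged-attention kernel expects:

import torch
import torch_xla.experimental.custom_kernel  # registers torch.ops.xla.ragged_paged_attention


def call_ragged_paged_attention(q, kv_pages, kv_lens, page_indices, cu_q_lens,
                                num_seqs, *, sm_scale=1.0, max_model_len=2048):
  # Forward to the custom op defined above; max_model_len defaults to 2048
  # and is threaded through to the kernel's block-size tuning.
  return torch.ops.xla.ragged_paged_attention(
      q,
      kv_pages,
      kv_lens,
      page_indices,
      cu_q_lens,
      num_seqs,
      sm_scale=sm_scale,
      use_kernel=True,
      max_model_len=max_model_len)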
