
Commit 9ec626e

Use pages_per_seq * page_size instead of directly passing max_model_len (#8950)

1 parent: 80929a9
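
The gist of the change: instead of asking callers for a max_model_len hint, the kernel wrapper now derives it from the KV-cache layout itself, since a sequence can never span more than pages_per_seq * page_size tokens. Below is a minimal standalone sketch of that derivation; the helper name and the concrete shapes are illustrative assumptions, and only the shape arithmetic mirrors the lines added by this commit.

import torch

def derive_tuning_inputs(q, kv_pages, page_indices):
  # Sketch of the shape-based derivation this commit adds (hypothetical helper,
  # not part of torch_xla).
  #   q:            [token_num, q_head_num, head_dim]
  #   kv_pages:     [num_pages, page_size, num_combined_kv_heads, head_dim]
  #   page_indices: [max_num_seqs, pages_per_seq]
  token_num, q_head_num, _ = q.shape
  _, page_size, num_combined_kv_heads, _ = kv_pages.shape
  _, pages_per_seq = page_indices.shape
  num_kv_heads = num_combined_kv_heads // 2  # K and V heads share the combined axis
  max_model_len = pages_per_seq * page_size  # longest sequence the page table can address
  return q_head_num, num_kv_heads, token_num, max_model_len

# Illustrative shapes only: 128 tokens, 8 query heads, 1 KV head (2 combined),
# 16-token pages, 8 pages per sequence -> max_model_len = 8 * 16 = 128.
q = torch.empty(128, 8, 128)
kv_pages = torch.empty(64, 16, 2, 128)
page_indices = torch.zeros(4, 8, dtype=torch.int32)
print(derive_tuning_inputs(q, kv_pages, page_indices))  # (8, 1, 128, 128)

These four values are what both files now feed to get_ragged_attention_tuned_block_size whenever the caller leaves num_kv_pages_per_block and num_queries_per_block unset.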

2 files changed (+11, -14 lines)


test/test_pallas.py (+5, -7)

@@ -678,7 +678,6 @@ def ragged_paged_attention_wrapper(
         sliding_window=sliding_window,
         soft_cap=soft_cap,
         use_kernel=True,
-        max_model_len=2048,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
     ):
@@ -693,7 +692,6 @@ def ragged_paged_attention_wrapper(
         sliding_window=sliding_window,
         soft_cap=soft_cap,
         use_kernel=use_kernel,
-        max_model_len=max_model_len,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
     )
@@ -714,7 +712,6 @@ def ragged_paged_attention_wrapper(
         sliding_window=sliding_window,
         soft_cap=soft_cap,
         use_kernel=True,
-        max_model_len=2048,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
     )[:cu_q_lens[num_seqs]]
@@ -755,14 +752,15 @@ def ragged_paged_attention_wrapper(
 
     from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
     from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size
-    max_model_len = 2048
     if num_kv_pages_per_block is None:
       assert num_queries_per_block is None
-      token_num = q.shape[0]
       token_num, q_head_num, _ = q.shape
-      kv_head_num = kv_pages[2] // 2
+      _, page_size, num_combined_kv_heads, _ = kv_pages.shape
+      _, pages_per_seq = page_indices.shape
+      num_kv_heads = num_combined_kv_heads // 2
+      max_model_len = pages_per_seq * page_size
       num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
-          q_head_num, kv_head_num, token_num, max_model_len)
+          q_head_num, num_kv_heads, token_num, max_model_len)
     jax_kernel_output = torch.from_numpy(
         np.array(
             jax_ragged_paged_attention(

torch_xla/experimental/custom_kernel.py (+6, -7)

@@ -930,7 +930,6 @@ def ragged_paged_attention(
     soft_cap: float | None = None,
     mask_value=None,
     use_kernel=True,
-    max_model_len=2048,  # Used as a hint for the kernel block sizes selection
     # kernel tuning parameters
     num_kv_pages_per_block=None,
     num_queries_per_block=None,
@@ -960,9 +959,12 @@ def ragged_paged_attention(
   if num_kv_pages_per_block is None:
     assert num_queries_per_block is None
     token_num, q_head_num, _ = q.shape
-    kv_head_num = kv_pages[2] // 2
+    _, page_size, num_combined_kv_heads, _ = kv_pages.shape
+    _, pages_per_seq = page_indices.shape
+    num_kv_heads = num_combined_kv_heads // 2
+    max_model_len = pages_per_seq * page_size
     num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
-        q_head_num, kv_head_num, token_num, max_model_len)
+        q_head_num, num_kv_heads, token_num, max_model_len)
 
   if vmem_limit_bytes is None:
     vmem_limit_bytes = 64 * 1024 * 1024
@@ -1681,7 +1683,7 @@ def non_xla_ragged_paged_attention(q, kv, attention_type):
 XLA_LIB.define(
     "ragged_paged_attention(Tensor q, Tensor kv_pages, Tensor kv_lens, Tensor page_indices, "
     "Tensor cu_q_lens, Tensor num_seqs, float sm_scale=1, int? sliding_window=None, "
-    "float? soft_cap=None, float? mask_value=None, bool use_kernel=True, int max_model_len=2048,"
+    "float? soft_cap=None, float? mask_value=None, bool use_kernel=True,"
     "int? num_kv_pages_per_block=None, int? num_queries_per_block=None, int? vmem_limit_bytes=None) -> Tensor",
 )
 
@@ -1699,7 +1701,6 @@ def ragged_paged_attention_xla(
     soft_cap: float | None = None,
     mask_value=None,
     use_kernel=True,
-    max_model_len=2048,
     # kernel tuning parameters
     num_kv_pages_per_block=None,
     num_queries_per_block=None,
@@ -1717,7 +1718,6 @@ def ragged_paged_attention_xla(
       soft_cap=soft_cap,
       mask_value=mask_value,
       use_kernel=use_kernel,
-      max_model_len=max_model_len,
       num_kv_pages_per_block=num_kv_pages_per_block,
       num_queries_per_block=num_queries_per_block,
       vmem_limit_bytes=vmem_limit_bytes)
@@ -1736,7 +1736,6 @@ def ragged_paged_attention_non_xla(
     soft_cap: float | None = None,
     mask_value=None,
     use_kernel=True,
-    max_model_len=2048,
     # kernel tuning parameters
     num_kv_pages_per_block=None,
     num_queries_per_block=None,
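
Since max_model_len is gone from both the op schema and the Python entry points, callers simply drop the argument; when the tuning parameters are left as None, the block sizes are now selected from the derived pages_per_seq * page_size. The call below is a hedged sketch against the updated signature: the shapes, dtypes, and values are assumptions for illustration only, and it needs a TPU-backed XLA device to actually execute.

# Sketch only: assumed shapes and dtypes; requires a TPU XLA device at runtime.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import ragged_paged_attention

device = xm.xla_device()
head_dim, page_size, pages_per_seq = 128, 16, 8  # assumed cache layout

q = torch.randn(64, 8, head_dim, dtype=torch.bfloat16, device=device)  # [tokens, q_heads, head_dim]
kv_pages = torch.randn(32, page_size, 2, head_dim, dtype=torch.bfloat16,
                       device=device)  # combined K/V heads on one axis
kv_lens = torch.tensor([40, 24], dtype=torch.int32, device=device)
page_indices = torch.zeros(2, pages_per_seq, dtype=torch.int32, device=device)
cu_q_lens = torch.tensor([0, 32, 64], dtype=torch.int32, device=device)
num_seqs = torch.tensor([2], dtype=torch.int32, device=device)

# No max_model_len argument anymore; with the tuning parameters left unset the
# kernel picks block sizes from pages_per_seq * page_size (8 * 16 = 128 here).
out = ragged_paged_attention(
    q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs,
    sm_scale=head_dim**-0.5, use_kernel=True)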
