@@ -930,7 +930,6 @@ def ragged_paged_attention(
     soft_cap: float | None = None,
     mask_value=None,
     use_kernel=True,
-    max_model_len=2048,  # Used as a hint for the kernel block sizes selection
     # kernel tuning parameters
     num_kv_pages_per_block=None,
     num_queries_per_block=None,
@@ -960,9 +959,12 @@ def ragged_paged_attention(
   if num_kv_pages_per_block is None:
     assert num_queries_per_block is None
     token_num, q_head_num, _ = q.shape
-    kv_head_num = kv_pages[2] // 2
+    _, page_size, num_combined_kv_heads, _ = kv_pages.shape
+    _, pages_per_seq = page_indices.shape
+    num_kv_heads = num_combined_kv_heads // 2
+    max_model_len = pages_per_seq * page_size
     num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
-        q_head_num, kv_head_num, token_num, max_model_len)
+        q_head_num, num_kv_heads, token_num, max_model_len)

   if vmem_limit_bytes is None:
     vmem_limit_bytes = 64 * 1024 * 1024
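
With this change the block-size auto-tuning path no longer needs a caller-supplied max_model_len hint: it derives the same quantity from the shapes of kv_pages and page_indices. Below is a minimal, illustrative sketch of that shape arithmetic, not part of the patch; the layouts ([total_num_pages, page_size, num_combined_kv_heads, head_dim] and [max_num_seqs, pages_per_seq]) are taken from the unpacking above, and the concrete sizes are made-up example values.

import torch

# Toy sizes, illustrative only.
kv_pages = torch.empty(8, 16, 16, 64)                # 16 combined heads = K heads + V heads
page_indices = torch.empty(4, 128, dtype=torch.int32)  # 4 sequences, 128 pages per sequence

_, page_size, num_combined_kv_heads, _ = kv_pages.shape
_, pages_per_seq = page_indices.shape
num_kv_heads = num_combined_kv_heads // 2   # 8 KV heads
max_model_len = pages_per_seq * page_size   # 128 pages/seq * 16 tokens/page = 2048 tokens
print(num_kv_heads, max_model_len)          # 8 2048

Since pages_per_seq * page_size is the maximum number of KV tokens a sequence can address, it is an upper bound on the model length and serves the same role the old hint did, without a separate argument to keep in sync.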
@@ -1681,7 +1683,7 @@ def non_xla_ragged_paged_attention(q, kv, attention_type):
 XLA_LIB.define(
     "ragged_paged_attention(Tensor q, Tensor kv_pages, Tensor kv_lens, Tensor page_indices, "
     "Tensor cu_q_lens, Tensor num_seqs, float sm_scale=1, int? sliding_window=None, "
-    "float? soft_cap=None, float? mask_value=None, bool use_kernel=True, int max_model_len=2048, "
+    "float? soft_cap=None, float? mask_value=None, bool use_kernel=True,"
     "int? num_kv_pages_per_block=None, int? num_queries_per_block=None, int? vmem_limit_bytes=None) -> Tensor",
 )

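As a quick, illustrative sanity check (not part of the patch): Python joins the adjacent string literals passed to XLA_LIB.define into one schema string, and after this change that string no longer advertises max_model_len. The snippet below only reproduces the concatenation; it makes no claim about how the schema is parsed.

# Reproduce the concatenated schema string and confirm the removed argument is gone.
# Note the result keeps the comma between use_kernel and num_kv_pages_per_block,
# just without a space after it.
schema = (
    "ragged_paged_attention(Tensor q, Tensor kv_pages, Tensor kv_lens, Tensor page_indices, "
    "Tensor cu_q_lens, Tensor num_seqs, float sm_scale=1, int? sliding_window=None, "
    "float? soft_cap=None, float? mask_value=None, bool use_kernel=True,"
    "int? num_kv_pages_per_block=None, int? num_queries_per_block=None, int? vmem_limit_bytes=None) -> Tensor")
assert "max_model_len" not in schema
print(schema)
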
@@ -1699,7 +1701,6 @@ def ragged_paged_attention_xla(
     soft_cap: float | None = None,
     mask_value=None,
     use_kernel=True,
-    max_model_len=2048,
     # kernel tuning parameters
     num_kv_pages_per_block=None,
     num_queries_per_block=None,
@@ -1717,7 +1718,6 @@ def ragged_paged_attention_xla(
       soft_cap=soft_cap,
       mask_value=mask_value,
       use_kernel=use_kernel,
-      max_model_len=max_model_len,
       num_kv_pages_per_block=num_kv_pages_per_block,
       num_queries_per_block=num_queries_per_block,
       vmem_limit_bytes=vmem_limit_bytes)
@@ -1736,7 +1736,6 @@ def ragged_paged_attention_non_xla(
     soft_cap: float | None = None,
     mask_value=None,
     use_kernel=True,
-    max_model_len=2048,
     # kernel tuning parameters
     num_kv_pages_per_block=None,
     num_queries_per_block=None,