@@ -1681,7 +1681,7 @@ def non_xla_ragged_paged_attention(q, kv, attention_type):
 XLA_LIB.define(
     "ragged_paged_attention(Tensor q, Tensor kv_pages, Tensor kv_lens, Tensor page_indices, "
     "Tensor cu_q_lens, Tensor num_seqs, float sm_scale=1, int? sliding_window=None, "
-    "float? soft_cap=None, float? mask_value=None, bool use_kernel=True, "
+    "float? soft_cap=None, float? mask_value=None, bool use_kernel=True, int max_model_len=2048, "
     "int? num_kv_pages_per_block=None, int? num_queries_per_block=None, int? vmem_limit_bytes=None) -> Tensor",
 )

@@ -1699,6 +1699,7 @@ def ragged_paged_attention_xla(
     soft_cap: float | None = None,
     mask_value=None,
     use_kernel=True,
+    max_model_len=2048,
     # kernel tuning parameters
     num_kv_pages_per_block=None,
     num_queries_per_block=None,
@@ -1716,6 +1717,7 @@ def ragged_paged_attention_xla(
       soft_cap=soft_cap,
       mask_value=mask_value,
       use_kernel=use_kernel,
+      max_model_len=max_model_len,
       num_kv_pages_per_block=num_kv_pages_per_block,
       num_queries_per_block=num_queries_per_block,
       vmem_limit_bytes=vmem_limit_bytes)
@@ -1734,6 +1736,7 @@ def ragged_paged_attention_non_xla(
     soft_cap: float | None = None,
     mask_value=None,
     use_kernel=True,
+    max_model_len=2048,
     # kernel tuning parameters
     num_kv_pages_per_block=None,
     num_queries_per_block=None,
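
For context, a minimal call sketch of the updated op schema. This is an illustration of the new argument, not code from this change: it assumes the op registered on XLA_LIB resolves to torch.ops.xla.ragged_paged_attention, and the q, kv_pages, kv_lens, page_indices, cu_q_lens, and num_seqs tensors are placeholders whose shapes follow the kernel's paged-KV layout.

    import torch
    import torch_xla.experimental.custom_kernel  # assumed module path; importing it registers the op

    # Hypothetical usage sketch: pass the new max_model_len argument explicitly.
    output = torch.ops.xla.ragged_paged_attention(
        q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs,
        sm_scale=1.0,
        sliding_window=None,
        soft_cap=None,
        mask_value=None,
        use_kernel=True,
        max_model_len=4096,  # new parameter introduced by this diff; schema default is 2048
        num_kv_pages_per_block=None,
        num_queries_per_block=None,
        vmem_limit_bytes=None)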