
Use new tuned table #9041


Merged: 3 commits, Apr 28, 2025
26 changes: 3 additions & 23 deletions test/test_pallas.py
@@ -637,8 +637,8 @@ def _test_ragged_paged_attention(
sm_scale=1.0,
sliding_window=None,
soft_cap=None,
-num_kv_pages_per_block=16,
-num_queries_per_block=128,
+num_kv_pages_per_block=None,
+num_queries_per_block=None,
pad_tokens_and_seqs=False,
use_dynamo=True,
):
@@ -751,16 +751,6 @@ def ragged_paged_attention_wrapper(
num_seqs_jax = jnp.array([num_seqs], dtype=jnp.int32)

from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
-from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size
-if num_kv_pages_per_block is None:
-  assert num_queries_per_block is None
-  token_num, q_head_num, _ = q.shape
-  _, page_size, num_combined_kv_heads, _ = kv_pages.shape
-  _, pages_per_seq = page_indices.shape
-  num_kv_heads = num_combined_kv_heads // 2
-  max_model_len = pages_per_seq * page_size
-  num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
-      q_head_num, num_kv_heads, token_num, max_model_len)
jax_kernel_output = torch.from_numpy(
np.array(
jax_ragged_paged_attention(
@@ -790,8 +780,7 @@ def ragged_paged_attention_wrapper(
sm_scale=[1.0, 0.5],
sliding_window=[None, 128],
soft_cap=[None, 10.0],
-pad_tokens_and_seqs=[False, True],
-block_sizes=[(16, 128), (None, None)])
+pad_tokens_and_seqs=[False, True])
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
def test_ragged_paged_attention_wrapper_with_dynamo(
@@ -803,12 +792,10 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
sliding_window,
soft_cap,
pad_tokens_and_seqs,
-block_sizes,
):
head_dim = 128
page_size = 16
num_pages = 1000
-num_kv_pages_per_block, num_queries_per_block = block_sizes

self._test_ragged_paged_attention(
seq_lens,
@@ -822,8 +809,6 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
soft_cap=soft_cap,
pad_tokens_and_seqs=pad_tokens_and_seqs,
use_dynamo=True,
-num_kv_pages_per_block=num_kv_pages_per_block,
-num_queries_per_block=num_queries_per_block,
)

@parameterized.product(
@@ -834,7 +819,6 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
sliding_window=[None, 128],
soft_cap=[None, 10.0],
pad_tokens_and_seqs=[False, True],
-block_sizes=[(16, 128), (None, None)],
)
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
@@ -847,12 +831,10 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
sliding_window,
soft_cap,
pad_tokens_and_seqs,
-block_sizes,
):
head_dim = 128
page_size = 16
num_pages = 1000
-num_kv_pages_per_block, num_queries_per_block = block_sizes

self._test_ragged_paged_attention(
seq_lens,
@@ -866,8 +848,6 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
soft_cap=soft_cap,
pad_tokens_and_seqs=pad_tokens_and_seqs,
use_dynamo=False,
-num_kv_pages_per_block=num_kv_pages_per_block,
-num_queries_per_block=num_queries_per_block,
)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
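The block-size lookup deleted from the test mirrors what the kernel path is now expected to do on its own: when both block-size arguments are left as None, the sizes come from a tuned table keyed by the query head count, KV head count, token count, and maximum model length. Below is a minimal sketch of that derivation, reusing only names visible in this diff; it assumes get_ragged_attention_tuned_block_size remains importable from torch_xla.experimental.tuned_block_sizes, and the helper wrapping it is hypothetical.

```python
# Sketch of the lookup the removed test code performed explicitly.
# Assumes get_ragged_attention_tuned_block_size is still importable; the
# resolve_block_sizes helper below is hypothetical, not part of torch_xla.
from torch_xla.experimental.tuned_block_sizes import (
    get_ragged_attention_tuned_block_size)


def resolve_block_sizes(q, kv_pages, page_indices):
  """Derive (num_kv_pages_per_block, num_queries_per_block) from shapes."""
  token_num, q_head_num, _ = q.shape
  _, page_size, num_combined_kv_heads, _ = kv_pages.shape
  _, pages_per_seq = page_indices.shape
  # K and V heads are packed together in kv_pages, so halve the combined count.
  num_kv_heads = num_combined_kv_heads // 2
  max_model_len = pages_per_seq * page_size
  return get_ragged_attention_tuned_block_size(q_head_num, num_kv_heads,
                                               token_num, max_model_len)
```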
11 changes: 0 additions & 11 deletions torch_xla/experimental/custom_kernel.py
@@ -9,7 +9,6 @@
from torch_xla.distributed.spmd import Mesh
import torch_xla.distributed.spmd as xs
from torch_xla._internal.jax_workarounds import requires_jax
-from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size

# Re-expose this API used that is referenced by docs
from torch_xla._internal.jax_workarounds import jax_import_guard # noqa: F401, pylint: disable=unused-import
@@ -956,16 +955,6 @@ def ragged_paged_attention(
# in the global scope which could cause problems for xmp.spawn.
from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_attention

-if num_kv_pages_per_block is None:
-  assert num_queries_per_block is None
-  token_num, q_head_num, _ = q.shape
-  _, page_size, num_combined_kv_heads, _ = kv_pages.shape
-  _, pages_per_seq = page_indices.shape
-  num_kv_heads = num_combined_kv_heads // 2
-  max_model_len = pages_per_seq * page_size
-  num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
-      q_head_num, num_kv_heads, token_num, max_model_len)
-
if vmem_limit_bytes is None:
vmem_limit_bytes = 64 * 1024 * 1024

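With the wrapper-side lookup removed, callers can pass None for both num_kv_pages_per_block and num_queries_per_block and the choice is presumably deferred to the kernel side, consistent with the PR title; the remaining context lines show the wrapper still defaults vmem_limit_bytes to 64 MiB. The pattern involved is a shape-keyed tuned table with a generic fallback. A purely illustrative sketch follows; every name and table entry in it is invented for the example and is not a torch_xla API.

```python
# Purely illustrative: a tuned-table lookup with a fallback, the pattern a
# caller relies on when passing None for both block-size arguments.
# The key layout and table contents here are invented for the example.
_TUNED_BLOCK_SIZES = {
    # (q_heads, kv_heads, max_model_len) ->
    #     (num_kv_pages_per_block, num_queries_per_block)
    (32, 8, 4096): (32, 64),
    (8, 1, 2048): (16, 128),
}

# Fallback matching the defaults the tests hard-coded before this change.
_DEFAULT_BLOCK_SIZES = (16, 128)


def tuned_block_sizes(q_heads, kv_heads, max_model_len):
  """Return tuned block sizes for a shape, falling back to the defaults."""
  return _TUNED_BLOCK_SIZES.get((q_heads, kv_heads, max_model_len),
                                _DEFAULT_BLOCK_SIZES)
```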