
Commit 53ed842

vanbasten23 authored and pgmoka committed
Add padding ragged paged attention test (#8741)
1 parent fa30a7d commit 53ed842

File tree

3 files changed: +222, -45 lines

test/test_pallas.py

Lines changed: 141 additions & 31 deletions
@@ -87,6 +87,9 @@ def _pagedattention_generate_qkv(
     q = torch.randn(batch_size, query_len, num_heads, head_dim, dtype=dtype)
     return q, k_pages, v_pages, page_indices
 
+  def _round_up_closest_multiple_of(self, x, base):
+    return (x + base - 1) // base * base
+
   def _ragged_pagedattention_generate_qkv(
       self,
       seq_lens,
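The new helper is plain ceiling-division rounding. A quick standalone check (the values are illustrative, not from the tests above):

def round_up_closest_multiple_of(x, base):
  return (x + base - 1) // base * base

assert round_up_closest_multiple_of(506, 128) == 512  # 6 rows of padding
assert round_up_closest_multiple_of(512, 128) == 512  # already a multiple
assert round_up_closest_multiple_of(1, 16) == 16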
@@ -95,6 +98,8 @@ def _ragged_pagedattention_generate_qkv(
       page_size,
       num_pages,
       dtype=torch.float32,
+      num_queries_per_block=None,
+      pad_num_q_tokens=False,
   ):
     num_seqs = len(seq_lens)
     # Make sure the q_len is no longer than the kv_len. For example,
@@ -106,7 +111,10 @@ def _ragged_pagedattention_generate_qkv(
       assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}"
 
     query_lens = [seq_len[0] for seq_len in seq_lens]
-    num_q_tokens = sum(query_lens)
+    actual_num_q_tokens = sum(query_lens)
+    num_q_tokens = self._round_up_closest_multiple_of(
+        actual_num_q_tokens,
+        num_queries_per_block) if pad_num_q_tokens else actual_num_q_tokens
     kv_lens = torch.tensor([seq_len[1] for seq_len in seq_lens],
                            dtype=torch.int32)
     num_q_heads = num_heads[0]
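The padding is opt-in: with the default pad_num_q_tokens=False the generator behaves exactly as before. A rough sketch of the two paths, using the query lengths (1, 5, 500) from the padding tests below:

query_lens = [1, 5, 500]
actual_num_q_tokens = sum(query_lens)        # 506

# pad_num_q_tokens=False (default): unchanged behavior.
num_q_tokens = actual_num_q_tokens           # 506

# pad_num_q_tokens=True with num_queries_per_block=128:
num_q_tokens = (506 + 128 - 1) // 128 * 128  # 512, i.e. 6 padded query rows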
@@ -727,34 +735,28 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
         torch.allclose(
             output.cpu(), nonkernel_output.cpu(), atol=2e-1, rtol=1e-2))
 
-  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
-                   "This test only works on TPUv4+.")
-  def test_ragged_paged_attention_wrapper_with_dynamo(self):
-    seq_lens = [
-        (1, 1328),
-        (5, 18),
-        (1, 129),
-        (120, 229),
-        (1, 122),  # first physical q block
-        (1, 64),
-        (32, 100),
-        (250, 463),
-        (1, 18),
-        (1, 17),
-        (99, 123)
-    ]  # last 3 physical q blocks [(q_len, kv_len),...]
-    num_heads = (4, 4)
-    head_dim = 128
-    dtype = torch.float32
-    page_size = 16
-    num_pages = 32768
+  def _verify_ragged_paged_attention_with_dynamo(
+      self,
+      seq_lens,
+      num_heads,
+      head_dim,
+      page_size,
+      num_pages,
+      dtype,
+      num_kv_pages_per_block,
+      num_queries_per_block,
+      pad_num_q_tokens=False,
+  ):
     num_seqs = len(seq_lens)
-    num_kv_pages_per_block = 128
-    num_queries_per_block = 8
-    block_kv_size = 256
-
     q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv(
-        seq_lens, num_heads, head_dim, page_size, num_pages, dtype=dtype)
+        seq_lens,
+        num_heads,
+        head_dim,
+        page_size,
+        num_pages,
+        dtype=dtype,
+        num_queries_per_block=num_queries_per_block,
+        pad_num_q_tokens=pad_num_q_tokens)
 
     q_xla = q.to("xla")
     k_pages_xla = k_pages.to("xla")
@@ -783,7 +785,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
     compiled_paged_attention = torch.compile(
         ragged_paged_attention_wrapper, backend="openxla")
 
-    output = compiled_paged_attention(
+    kernel_output = compiled_paged_attention(
         q_xla,
         k_pages_xla,
         v_pages_xla,
@@ -809,9 +811,117 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
         use_kernel=False,
     )
 
-    self.assertTrue(
-        torch.allclose(
-            output.cpu(), nonkernel_output.cpu(), atol=2e-1, rtol=1e-2))
+    kernel_output_cpu = kernel_output.cpu()
+    nonkernel_output_cpu = nonkernel_output.cpu()
+    self.assertEqual(kernel_output_cpu.shape, nonkernel_output_cpu.shape)
+    self.assertEqual(kernel_output_cpu.dtype, nonkernel_output_cpu.dtype)
+
+    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
+    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
+    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
+    kv_lens_jax = jnp.array(kv_lens.numpy(), dtype=jnp.int32)
+    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
+    cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32)
+
+    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention
+    jax_kernel_output = torch.from_numpy(
+        np.array(
+            jax_ragged_paged_attention(
+                q_jax,
+                k_pages_jax,
+                v_pages_jax,
+                kv_lens_jax,
+                page_indices_jax,
+                cu_q_lens_jax,
+                num_seqs=num_seqs,
+                num_kv_pages_per_block=num_kv_pages_per_block,
+                num_queries_per_block=num_queries_per_block,
+            )[1]))
+    jax_kernel_output_cpu = jax_kernel_output.cpu()
+
+    if pad_num_q_tokens:
+      actual_num_q_tokens = cu_q_lens[num_seqs]
+      self.assertTrue(
+          torch.allclose(
+              kernel_output_cpu[:actual_num_q_tokens],
+              nonkernel_output_cpu[:actual_num_q_tokens],
+              atol=2e-1,
+              rtol=1e-2))
+      self.assertTrue(
+          torch.allclose(
+              kernel_output_cpu[:actual_num_q_tokens],
+              jax_kernel_output_cpu[:actual_num_q_tokens],
+              atol=2e-1,
+              rtol=1e-2))
+    else:
+      self.assertTrue(
+          torch.allclose(
+              kernel_output_cpu, nonkernel_output_cpu, atol=2e-1, rtol=1e-2))
+      self.assertTrue(
+          torch.allclose(
+              kernel_output_cpu, jax_kernel_output_cpu, atol=2e-1, rtol=1e-2))
+
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
+    seq_lens = [
+        (1, 1328),
+        (5, 18),
+        (1, 129),
+        (120, 229),
+        (1, 122),  # first physical q block
+        (1, 64),
+        (32, 100),
+        (250, 463),
+        (1, 18),
+        (1, 17),
+        (99, 123)
+    ]  # last 3 physical q blocks [(q_len, kv_len),...]
+    num_heads = (4, 4)
+    head_dim = 128
+    dtype = torch.float32
+    page_size = 16
+    num_pages = 32768
+
+    self._verify_ragged_paged_attention_with_dynamo(
+        seq_lens,
+        num_heads,
+        head_dim,
+        page_size,
+        num_pages,
+        dtype,
+        num_kv_pages_per_block=128,
+        num_queries_per_block=8,
+    )
+
+  @parameterized.product(
+      seq_lens=[[(1, 1328), (5, 18), (500, 563)]],
+      num_queries_per_block=[16, 64, 128],
+  )
+  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
+                   "This test only works on TPUv4+.")
+  def test_ragged_paged_attention_wrapper_with_query_padding_with_dynamo(
+      self,
+      seq_lens,
+      num_queries_per_block,
+  ):
+    num_heads = (4, 4)
+    head_dim = 128
+    dtype = torch.float32
+    page_size = 16
+    num_pages = 32768
+
+    self._verify_ragged_paged_attention_with_dynamo(
+        seq_lens,
+        num_heads,
+        head_dim,
+        page_size,
+        num_pages,
+        dtype,
+        num_kv_pages_per_block=128,
+        num_queries_per_block=num_queries_per_block,
+        pad_num_q_tokens=True,
+    )
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")

test/test_ragged_paged_attention_kernel.py

Lines changed: 71 additions & 11 deletions
@@ -1,16 +1,16 @@
 from typing import List, Optional, Tuple
+import sys
+import unittest
 
-from absl.testing import absltest
 from absl.testing import parameterized
+from absl.testing import absltest
 import jax
 from jax._src import test_util as jtu
 from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils
 from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention, make_sequence_metadata, DEFAULT_MASK_VALUE
 import jax.numpy as jnp
 import numpy as np
 
-jax.config.parse_flags_with_absl()
-
 ATOL_FP32 = 2e-1
 
 
@@ -29,6 +29,7 @@ def _ref_ragged_paged_attention(
   assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0."
   num_query_per_kv = num_q_heads // num_kv_heads
   start_idx = 0
+
   outputs: List[jax.Array] = []
   for i in range(num_seqs):
     cur_q_len = cu_q_lens[i + 1] - cu_q_lens[i]
@@ -72,11 +73,17 @@ def _ref_ragged_paged_attention(
     outputs.append(out)
     start_idx += cur_q_len
 
+  maybe_padded_num_q_tokens = queries.shape[0]
+  actual_num_tokens = cu_q_lens[num_seqs]
+  if actual_num_tokens < maybe_padded_num_q_tokens:
+    num_tokens_diff = maybe_padded_num_q_tokens - actual_num_tokens
+    outputs.append(
+        jnp.zeros(
+            (num_tokens_diff, num_q_heads, head_dim)).astype(outputs[0].dtype))
   return jnp.concatenate(outputs, axis=0)
 
 
-@jtu.with_config(jax_numpy_dtype_promotion="standard")
-class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):
+class RaggedPagedAttentionKernelTest(parameterized.TestCase):
 
   def _verify_ragged_paged_attention(
       self,
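The reference implementation now zero-fills the rows past the real tokens so that its output shape matches a padded kernel output. A minimal standalone sketch of that padding step (shapes are illustrative):

import jax.numpy as jnp

num_q_heads, head_dim = 4, 128
actual_num_tokens, maybe_padded_num_q_tokens = 506, 512

outputs = [jnp.ones((actual_num_tokens, num_q_heads, head_dim), dtype=jnp.float32)]
if actual_num_tokens < maybe_padded_num_q_tokens:
  num_tokens_diff = maybe_padded_num_q_tokens - actual_num_tokens
  # Append zero rows so the reference output covers the padded tail as well.
  outputs.append(
      jnp.zeros((num_tokens_diff, num_q_heads, head_dim)).astype(outputs[0].dtype))

out = jnp.concatenate(outputs, axis=0)
assert out.shape == (maybe_padded_num_q_tokens, num_q_heads, head_dim)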
@@ -88,6 +95,7 @@ def _verify_ragged_paged_attention(
       num_pages,
       num_kv_pages_per_block=128,
       num_queries_per_block=128,
+      pad_num_q_tokens=False,
   ):
     num_seqs = len(seq_lens)
     # Make sure the q_len is no longer than the kv_len. For example,
@@ -99,7 +107,11 @@ def _verify_ragged_paged_attention(
       assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}"
 
     query_lens = [seq_len[0] for seq_len in seq_lens]
-    num_q_tokens = sum(query_lens)
+    actual_num_q_tokens = sum(query_lens)
+    # The caller (e.g. vLLM) may decide to pad num_q_tokens.
+    num_q_tokens = self._round_up_closest_multiple_of(
+        actual_num_q_tokens,
+        num_queries_per_block) if pad_num_q_tokens else actual_num_q_tokens
     kv_lens = jnp.array([seq_len[1] for seq_len in seq_lens])
     num_q_heads = num_heads[0]
     num_kv_heads = num_heads[1]
@@ -115,6 +127,8 @@ def _verify_ragged_paged_attention(
         k3, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype)
 
     # Create a kv_lens: i32[num_tokens]
+    # Only the first num_seqs entries of kv_lens_with_paddings are meaningful;
+    # entries [num_seqs:num_q_tokens] are padding and are meaningless.
     kv_lens_with_paddings = [0] * num_q_tokens
     for i in range(num_seqs):
       kv_lens_with_paddings[i] = kv_lens[i]
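For illustration, with the three-sequence case used in the padding tests the padded kv_lens array carries real values only in its first num_seqs entries (a minimal sketch, sizes assumed):

num_q_tokens = 512           # padded query-token count (assumed)
num_seqs = 3
kv_lens = [1328, 18, 563]    # per-sequence kv lengths from seq_lens

kv_lens_with_paddings = [0] * num_q_tokens
for i in range(num_seqs):
  kv_lens_with_paddings[i] = kv_lens[i]

# Only the first num_seqs entries carry data; the rest stay 0 and are ignored.
assert kv_lens_with_paddings[:4] == [1328, 18, 563, 0]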
@@ -182,8 +196,16 @@ def _verify_ragged_paged_attention(
       rtol = 1e-1
     else:
       self.fail(f'Unsupported dtype: {dtype}')
-    self.assertTrue(
-        jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol))
+    if pad_num_q_tokens:
+      self.assertTrue(
+          jnp.allclose(
+              actual_output[:actual_num_q_tokens],
+              expected_output[:actual_num_q_tokens],
+              atol=atol,
+              rtol=rtol))
+    else:
+      self.assertTrue(
+          jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol))
 
   def _round_up_closest_multiple_of(self, x, base):
     return (x + base - 1) // base * base
@@ -215,11 +237,12 @@ def test_paged_attention_basic(self,):
 
   @parameterized.product(
       seq_lens=[[(1, 1328), (5, 18), (506, 563)]],
-      num_heads=[(4, 4), (8, 2), (16, 2)],
+      num_heads=[(4, 4), (4, 2)],
       head_dim=[128, 256],
       dtype=(jnp.float32, jnp.bfloat16),
       page_size=[16, 32],
       num_pages=[32768, 2048],
+      num_queries_per_block=[16, 64, 128],
   )
   def test_paged_attention_varlen_comprehensive(
       self,
@@ -229,6 +252,7 @@ def test_paged_attention_varlen_comprehensive(
       dtype,
       page_size: int,
      num_pages: int,
+      num_queries_per_block: int,
   ):
     if jtu.is_device_tpu(version=4) and head_dim == 256 and page_size == 32:
       self.skipTest(
@@ -240,7 +264,42 @@ def test_paged_attention_varlen_comprehensive(
         page_size,
         dtype,
         num_pages,
-        num_queries_per_block=64,
+        num_queries_per_block=num_queries_per_block,
+        num_kv_pages_per_block=128,
+    )
+
+  @parameterized.product(
+      num_heads=[(4, 4), (4, 2)],
+      head_dim=[128, 256],
+      dtype=(jnp.float32, jnp.bfloat16),
+      page_size=[16, 32],
+      num_pages=[32768, 2048],
+      num_queries_per_block=[16, 64, 128],
+  )
+  def test_paged_attention_varlen_with_padding_comprehensive(
+      self,
+      num_heads: Tuple[int, int],
+      head_dim: int,
+      dtype,
+      page_size: int,
+      num_pages: int,
+      num_queries_per_block: int,
+  ):
+    if jtu.is_device_tpu(version=4) and head_dim == 256 and page_size == 32:
+      self.skipTest(
+          "TPU v4 has small VMEM. It will run into VMEM OOM. Skip the test.")
+    # With num_queries_per_block=128, num_tokens (506) is padded by 6 up to 512, the smallest multiple of 128.
+    seq_lens = [(1, 1328), (5, 18), (500, 563)]
+    self._verify_ragged_paged_attention(
+        seq_lens,
+        num_heads,
+        head_dim,
+        page_size,
+        dtype,
+        num_pages,
+        num_queries_per_block=num_queries_per_block,
+        num_kv_pages_per_block=128,
+        pad_num_q_tokens=True,
     )
 
   def test_paged_attention_mix_prefill_and_decode1(self,):
@@ -442,4 +501,5 @@ def test_make_sequence_metadata(self,):
 
 
 if __name__ == "__main__":
-  absltest.main(testLoader=jtu.JaxTestLoader())
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
