from typing import List, Optional, Tuple
+import sys
+import unittest

-from absl.testing import absltest
from absl.testing import parameterized
+from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils
from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention, make_sequence_metadata, DEFAULT_MASK_VALUE
import jax.numpy as jnp
import numpy as np

-jax.config.parse_flags_with_absl()
-
ATOL_FP32 = 2e-1

@@ -29,6 +29,7 @@ def _ref_ragged_paged_attention(
  assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0."
  num_query_per_kv = num_q_heads // num_kv_heads
  start_idx = 0
+
  outputs: List[jax.Array] = []
  for i in range(num_seqs):
    cur_q_len = cu_q_lens[i + 1] - cu_q_lens[i]
@@ -72,11 +73,17 @@ def _ref_ragged_paged_attention(
    outputs.append(out)
    start_idx += cur_q_len

+  maybe_padded_num_q_tokens = queries.shape[0]
+  actual_num_tokens = cu_q_lens[num_seqs]
+  if actual_num_tokens < maybe_padded_num_q_tokens:
+    num_tokens_diff = maybe_padded_num_q_tokens - actual_num_tokens
+    outputs.append(
+        jnp.zeros(
+            (num_tokens_diff, num_q_heads, head_dim)).astype(outputs[0].dtype))
  return jnp.concatenate(outputs, axis=0)


-@jtu.with_config(jax_numpy_dtype_promotion="standard")
-class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):
+class RaggedPagedAttentionKernelTest(parameterized.TestCase):

  def _verify_ragged_paged_attention(
      self,
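
As a side note, the new tail-padding branch in _ref_ragged_paged_attention can be pictured with a minimal standalone sketch (not part of the diff; the shapes below are hypothetical, not taken from the tests): when the caller has padded the query token dimension, zero rows are appended so the concatenated reference output matches the padded shape.

    import jax.numpy as jnp

    num_q_heads, head_dim = 4, 128
    actual_num_tokens = 506          # cu_q_lens[num_seqs]: real tokens across all sequences
    maybe_padded_num_q_tokens = 512  # queries.shape[0] after the caller's padding
    outputs = [jnp.ones((actual_num_tokens, num_q_heads, head_dim), dtype=jnp.float32)]
    if actual_num_tokens < maybe_padded_num_q_tokens:
      num_tokens_diff = maybe_padded_num_q_tokens - actual_num_tokens
      # Zero rows for the padded tail, matching the dtype of the real outputs.
      outputs.append(
          jnp.zeros((num_tokens_diff, num_q_heads, head_dim)).astype(outputs[0].dtype))
    out = jnp.concatenate(outputs, axis=0)
    assert out.shape == (maybe_padded_num_q_tokens, num_q_heads, head_dim)
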
@@ -88,6 +95,7 @@ def _verify_ragged_paged_attention(
      num_pages,
      num_kv_pages_per_block=128,
      num_queries_per_block=128,
+      pad_num_q_tokens=False,
  ):
    num_seqs = len(seq_lens)
    # Make sure the q_len is no longer than the kv_len. For example,
@@ -99,7 +107,11 @@ def _verify_ragged_paged_attention(
      assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}"

    query_lens = [seq_len[0] for seq_len in seq_lens]
-    num_q_tokens = sum(query_lens)
+    actual_num_q_tokens = sum(query_lens)
+    # The caller (e.g. vLLM) may decide to pad num_q_tokens.
+    num_q_tokens = self._round_up_closest_multiple_of(
+        actual_num_q_tokens,
+        num_queries_per_block) if pad_num_q_tokens else actual_num_q_tokens
    kv_lens = jnp.array([seq_len[1] for seq_len in seq_lens])
    num_q_heads = num_heads[0]
    num_kv_heads = num_heads[1]
@@ -115,6 +127,8 @@ def _verify_ragged_paged_attention(
        k3, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype)

    # Create a kv_lens: i32[num_tokens]
+    # Only the first num_seqs entries of kv_lens_with_paddings are meaningful;
+    # the entries in [num_seqs:num_q_tokens] are padding and are meaningless.
    kv_lens_with_paddings = [0] * num_q_tokens
    for i in range(num_seqs):
      kv_lens_with_paddings[i] = kv_lens[i]
@@ -182,8 +196,16 @@ def _verify_ragged_paged_attention(
      rtol = 1e-1
    else:
      self.fail(f'Unsupported dtype: {dtype}')
-    self.assertTrue(
-        jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol))
+    if pad_num_q_tokens:
+      self.assertTrue(
+          jnp.allclose(
+              actual_output[:actual_num_q_tokens],
+              expected_output[:actual_num_q_tokens],
+              atol=atol,
+              rtol=rtol))
+    else:
+      self.assertTrue(
+          jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol))

  def _round_up_closest_multiple_of(self, x, base):
    return (x + base - 1) // base * base
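
To make the padding arithmetic concrete, here is a small standalone sketch (not part of the diff; the values are hypothetical and merely mirror the comment in the padding test added below, and the helper is re-declared as a free function for illustration). It rounds 506 real query tokens up to 512 and shows why, in the padded case, the verification above compares only the first actual_num_q_tokens rows.

    import jax.numpy as jnp

    def round_up_closest_multiple_of(x, base):
      # Smallest multiple of `base` that is >= x, via ceiling division.
      return (x + base - 1) // base * base

    actual_num_q_tokens = 1 + 5 + 500                  # 506 real query tokens
    num_q_tokens = round_up_closest_multiple_of(actual_num_q_tokens, 128)
    assert num_q_tokens == 512                         # padded by 6

    # Only the first actual_num_q_tokens rows of the padded outputs are
    # meaningful; the reference pads its tail with zeros, so a full-tensor
    # comparison could fail even when the kernel is correct.
    actual_output = jnp.ones((num_q_tokens, 4, 128))
    expected_output = actual_output.at[actual_num_q_tokens:].set(0.0)
    assert jnp.allclose(
        actual_output[:actual_num_q_tokens],
        expected_output[:actual_num_q_tokens],
        atol=2e-1,
        rtol=1e-1)
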
@@ -215,11 +237,12 @@ def test_paged_attention_basic(self,):

  @parameterized.product(
      seq_lens=[[(1, 1328), (5, 18), (506, 563)]],
-      num_heads=[(4, 4), (8, 2), (16, 2)],
+      num_heads=[(4, 4), (4, 2)],
      head_dim=[128, 256],
      dtype=(jnp.float32, jnp.bfloat16),
      page_size=[16, 32],
      num_pages=[32768, 2048],
+      num_queries_per_block=[16, 64, 128],
  )
  def test_paged_attention_varlen_comprehensive(
      self,
@@ -229,6 +252,7 @@ def test_paged_attention_varlen_comprehensive(
      dtype,
      page_size: int,
      num_pages: int,
+      num_queries_per_block: int,
  ):
    if jtu.is_device_tpu(version=4) and head_dim == 256 and page_size == 32:
      self.skipTest(
@@ -240,7 +264,42 @@ def test_paged_attention_varlen_comprehensive(
        page_size,
        dtype,
        num_pages,
-        num_queries_per_block=64,
+        num_queries_per_block=num_queries_per_block,
+        num_kv_pages_per_block=128,
+    )
+
+  @parameterized.product(
+      num_heads=[(4, 4), (4, 2)],
+      head_dim=[128, 256],
+      dtype=(jnp.float32, jnp.bfloat16),
+      page_size=[16, 32],
+      num_pages=[32768, 2048],
+      num_queries_per_block=[16, 64, 128],
+  )
+  def test_paged_attention_varlen_with_padding_comprehensive(
+      self,
+      num_heads: Tuple[int, int],
+      head_dim: int,
+      dtype,
+      page_size: int,
+      num_pages: int,
+      num_queries_per_block: int,
+  ):
+    if jtu.is_device_tpu(version=4) and head_dim == 256 and page_size == 32:
+      self.skipTest(
+          "TPU v4 has small VMEM. It will run into VMEM OOM. Skip the test.")
+    # If num_queries_per_block is 128, num_q_tokens (1 + 5 + 500 = 506) is padded by 6 up to 512, the smallest multiple of 128.
+    seq_lens = [(1, 1328), (5, 18), (500, 563)]
+    self._verify_ragged_paged_attention(
+        seq_lens,
+        num_heads,
+        head_dim,
+        page_size,
+        dtype,
+        num_pages,
+        num_queries_per_block=num_queries_per_block,
+        num_kv_pages_per_block=128,
+        pad_num_q_tokens=True,
    )

  def test_paged_attention_mix_prefill_and_decode1(self,):
@@ -442,4 +501,5 @@ def test_make_sequence_metadata(self,):


if __name__ == "__main__":
-  absltest.main(testLoader=jtu.JaxTestLoader())
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)