@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import random
 from typing import Optional
 
 import pytest
@@ -171,19 +170,31 @@ def ref_context_attention(
     return output
 
 
+@pytest.mark.parametrize(
+    "block_size, large_tile_size",
+    [
+        (32, 2048),  # 64 blocks
+        (32, 4096),  # 128 blocks
+        (32, 8192),  # 256 blocks
+        (64, 8192),  # 128 blocks
+    ],
+)
 @pytest.mark.parametrize(
     "num_heads,num_queries_per_kv,head_size,mixed_precision",
     [
         (4, 2, 8, False),
         (4, 2, 8, True),
         (32, 8, 64, True),
+        (16, 2, 128, True),
     ],
 )
 @torch.inference_mode()
 def test_contexted_kv_attention(
     num_heads: int,
     num_queries_per_kv: int,
     head_size: int,
+    block_size: int,
+    large_tile_size: int,
     mixed_precision: bool,
 ) -> None:
     import os
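
The `# N blocks` comments in the new parametrization encode `large_tile_size // block_size`. A minimal standalone sanity check of the grid values (illustration only, not part of the test file):

```python
# Each large tile holds large_tile_size // block_size KV-cache blocks,
# matching the "# N blocks" annotations in the parametrize grid above.
for block_size, large_tile_size, expected_blocks in [
    (32, 2048, 64),
    (32, 4096, 128),
    (32, 8192, 256),
    (64, 8192, 128),
]:
    assert large_tile_size % block_size == 0
    assert large_tile_size // block_size == expected_blocks
```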
@@ -192,40 +203,46 @@ def test_contexted_kv_attention(
 
     from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc
 
+    assert large_tile_size % block_size == 0
+
     device = xm.xla_device()
 
-    os.environ["NEURON_CC_FLAGS"] = (
-        " --model-type=transformer -O1 "
-        " --internal-hlo2tensorizer-options='--verify-hlo' ")
+    compiler_flags = [
+        "--model-type=transformer -O1",
+        "--internal-hlo2tensorizer-options='--verify-hlo'",
+        "--retry_failed_compilation",
+    ]
+    compiler_flags_str = " ".join(compiler_flags)
+    os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
 
-    random.seed(0)
     torch.manual_seed(0)
     torch.set_printoptions(sci_mode=False)
 
-    min_ctx_len = 2
-    max_ctx_len = 64
-    min_query_len = 2
-    max_query_len = 64
-    prefill_batch_size = 2
-    decode_batch_size = 6
+    min_ctx_len = 32
+    max_ctx_len = 1024
+    min_query_len = 16
+    max_query_len = 512
+    prefill_batch_size = 4
+    decode_batch_size = 12
     batch_size = prefill_batch_size + decode_batch_size
-    block_size = 32
     max_model_len = (max_query_len + max_ctx_len) * 4
 
     max_block_per_request = max_model_len // block_size
     dtype = torch.float32
     cache_size = (batch_size * max_block_per_request) + 2
-    ctx_lens = [
-        random.randint(min_ctx_len, max_ctx_len)
-        for _ in range(prefill_batch_size)
-    ] + [
-        random.randint(min_ctx_len, max_ctx_len)
-        for _ in range(decode_batch_size)
-    ]
-    query_lens = [
-        random.randint(min_query_len, max_query_len)
-        for _ in range(prefill_batch_size)
-    ] + [1 for _ in range(decode_batch_size)]
+    prefill_ctx_lens = torch.randint(min_ctx_len,
+                                     max_ctx_len + 1, (prefill_batch_size, ),
+                                     dtype=torch.long).tolist()
+    decode_ctx_lens = torch.randint(min_ctx_len,
+                                    max_ctx_len + 1, (decode_batch_size, ),
+                                    dtype=torch.long).tolist()
+    ctx_lens = prefill_ctx_lens + decode_ctx_lens
+    query_lens = torch.randint(
+        min_query_len,
+        max_query_len + 1,
+        (prefill_batch_size, ),
+        dtype=torch.long,
+    ).tolist() + [1 for _ in range(decode_batch_size)]
 
     seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
     num_kv_heads = num_heads // num_queries_per_kv
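
One effect of dropping the `random` module in favor of `torch.randint` in the hunk above is that a single `torch.manual_seed(0)` now governs all sampling in the test. A small sketch of that property, reusing the new bounds (`min_ctx_len=32`, `max_ctx_len=1024`, `prefill_batch_size=4`); this is illustration only:

```python
import torch

# torch.randint's upper bound is exclusive, hence the "+ 1" in the diff.
torch.manual_seed(0)
first = torch.randint(32, 1024 + 1, (4, ), dtype=torch.long).tolist()
torch.manual_seed(0)
second = torch.randint(32, 1024 + 1, (4, ), dtype=torch.long).tolist()
assert first == second, "same seed must reproduce the same lengths"
assert all(32 <= x <= 1024 for x in first)
```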
@@ -254,7 +271,6 @@ def test_contexted_kv_attention(
     values = values[torch.randperm(cache_size)]
     block_table = values[:batch_size * max_block_per_request].view(
         batch_size, max_block_per_request)
-    torch.tensor(seq_lens, dtype=torch.long)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
     b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                             dtype=torch.long),
@@ -311,9 +327,7 @@ def test_contexted_kv_attention(
     # build neuron program
     return_debug_tensors = False
     B_P_SIZE = 128
-    LARGE_TILE_SZ = 2048
-    max_num_queries = (
-        (sum(query_lens) + block_size - 1) // block_size) * block_size
+    LARGE_TILE_SZ = large_tile_size
 
     def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
                                 num_blocks):
@@ -332,26 +346,28 @@ def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
             0,
         )
 
-    def shift_bit_length(x):
-        return 1 << (x - 1).bit_length()
+    def ceil_div(a, b):
+        return (a + b - 1) // b
+
+    def pad_to_multiple(a, b):
+        return ceil_div(a, b) * b
+
+    def pad_to_next_power_of_2(a):
+        assert a > 0
+        return 2**int(a - 1).bit_length()
 
     # calculate input shapes
-    max_num_queries_shifted = shift_bit_length(max_num_queries)
-    max_num_queries_factor = B_P_SIZE // max_num_queries_shifted
-    max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor
-    assert (max_num_queries_padded == B_P_SIZE
-            ), "invalid {max_num_queries_padded=}"
+    max_num_queries = pad_to_multiple(sum(query_lens), block_size)
+    max_num_queries = pad_to_next_power_of_2(max_num_queries)
     head_size_padded = B_P_SIZE
+    assert head_size_padded >= head_size
     context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
-    num_active_blocks_shifted = shift_bit_length(
-        ((context_lens + block_size - 1) // block_size).sum().item())
-    num_active_blocks_factor = (LARGE_TILE_SZ // block_size //
-                                num_active_blocks_shifted)
-    num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor
-    assert (num_active_blocks *
-            block_size) == LARGE_TILE_SZ, "invalid {num_active_blocks=}"
+    num_active_blocks = ceil_div(context_lens, block_size).sum().item()
+    num_active_blocks = pad_to_multiple(num_active_blocks,
+                                        LARGE_TILE_SZ // block_size)
     context_kv_len = num_active_blocks * block_size
-    assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}"
+    assert (context_kv_len %
+            LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"
 
     # pad QKV tensors
     pad_dims = (
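
A worked example of the new padding helpers, using illustrative numbers rather than values from the test: 300 total query tokens with `block_size=32` and `LARGE_TILE_SZ=2048`, and 70 active context blocks.

```python
def ceil_div(a, b):
    return (a + b - 1) // b


def pad_to_multiple(a, b):
    return ceil_div(a, b) * b


def pad_to_next_power_of_2(a):
    assert a > 0
    return 2**int(a - 1).bit_length()


# Queries: round up to a block boundary, then to a power of two.
assert pad_to_multiple(300, 32) == 320
assert pad_to_next_power_of_2(320) == 512

# Context blocks: 70 active blocks padded to a multiple of the
# 2048 // 32 = 64 blocks per large tile gives 128 blocks, i.e.
# context_kv_len = 128 * 32 = 4096, a multiple of LARGE_TILE_SZ.
assert pad_to_multiple(70, 2048 // 32) == 128
```

Note the relaxation: the old code asserted `context_kv_len == LARGE_TILE_SZ` exactly, while the new code only requires it to be a multiple of `LARGE_TILE_SZ`, which is what allows the larger context lengths in the new parametrization.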
@@ -360,7 +376,7 @@ def shift_bit_length(x):
         0,
         0,
         0,
-        max_num_queries_padded - query.shape[0],
+        max_num_queries - query.shape[0],
     )
     query = F.pad(query, pad_dims, "constant", 0)
     k = F.pad(k, pad_dims, "constant", 0)
@@ -397,7 +413,7 @@ def shift_bit_length(x):
             0,
             context_kv_len - prior_mask.shape[1],
             0,
-            B_P_SIZE - prior_mask.shape[0],
+            max_num_queries - prior_mask.shape[0],
         ),
         "constant",
         0,
@@ -406,9 +422,9 @@ def shift_bit_length(x):
         active_mask,
         (
             0,
-            B_P_SIZE - active_mask.shape[1],
+            max_num_queries - active_mask.shape[1],
             0,
-            B_P_SIZE - active_mask.shape[0],
+            max_num_queries - active_mask.shape[0],
         ),
         "constant",
         0,
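
The two mask hunks above replace the fixed `B_P_SIZE` padding target with `max_num_queries`, so the masks scale with the workload instead of being capped at 128 queries. A shape-only sketch of the prior-mask padding, with all sizes assumed for illustration:

```python
import torch
import torch.nn.functional as F

context_kv_len, max_num_queries = 4096, 512  # assumed values
prior_mask = torch.ones(300, 3800)           # (num queries, context kv len)
# F.pad pads the last dimension first: grow the KV axis to context_kv_len,
# then the query axis to max_num_queries.
prior_mask_padded = F.pad(
    prior_mask,
    (
        0,
        context_kv_len - prior_mask.shape[1],
        0,
        max_num_queries - prior_mask.shape[0],
    ),
    "constant",
    0,
)
assert prior_mask_padded.shape == (max_num_queries, context_kv_len)
```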
@@ -430,6 +446,8 @@ def shift_bit_length(x):
         n_kv_head=num_kv_heads,
         head_size=head_size,
         mixed_precision=mixed_precision,
+        LARGE_TILE_SZ=LARGE_TILE_SZ,
+        return_debug_tensors=return_debug_tensors,
     )
 
     if return_debug_tensors:
@@ -439,17 +457,15 @@ def shift_bit_length(x):
         output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
         debug_tensors = []
 
-    output_nki = torch.tensor(output_nki).cpu()
     debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]
 
     num_actual_tokens = sum(query_lens)
-    print(f"{num_actual_tokens=}")
     # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
-    output_nki = output_nki.permute(
-        0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :]
+    output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
+    output_nki = output_nki[0, :num_actual_tokens, :, :]
     output_ref_padded = F.pad(
         output_ref,
-        (0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]),
+        (0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]),
         "constant",
         0,
    )
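
The rewritten output handling moves the kernel output to CPU once, then reorders and trims it. A shape-only trace with assumed sizes (4 heads, 512 padded queries, `head_size=64` padded to `B_P_SIZE=128`, 300 real tokens):

```python
import torch

out = torch.zeros(1, 4, 512, 128)  # (bs, n_heads, seq_q_padded, d_padded)
# -> (bs, seq_q_padded, n_heads, d), dropping the head_size padding
out = out.cpu().permute(0, 2, 1, 3)[:, :, :, :64]
out = out[0, :300, :, :]           # keep only the real tokens
assert out.shape == (300, 4, 64)
```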