
Commit e92694b

[Neuron][Kernel] Support Longer Sequences in NKI-based Flash PagedAttention and Improve Efficiency (#12921)
Signed-off-by: Lingfan Yu <[email protected]>
1 parent 842b0fd commit e92694b

File tree

2 files changed: +154 -180 lines changed

tests/neuron/test_prefix_prefill.py: +67 -51
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import random
 from typing import Optional
 
 import pytest
@@ -171,19 +170,31 @@ def ref_context_attention(
     return output
 
 
+@pytest.mark.parametrize(
+    "block_size, large_tile_size",
+    [
+        (32, 2048),  # 64 blocks
+        (32, 4096),  # 128 blocks
+        (32, 8192),  # 256 blocks
+        (64, 8192),  # 128 blocks
+    ],
+)
 @pytest.mark.parametrize(
     "num_heads,num_queries_per_kv,head_size,mixed_precision",
     [
         (4, 2, 8, False),
         (4, 2, 8, True),
         (32, 8, 64, True),
+        (16, 2, 128, True),
     ],
 )
 @torch.inference_mode()
 def test_contexted_kv_attention(
     num_heads: int,
     num_queries_per_kv: int,
     head_size: int,
+    block_size: int,
+    large_tile_size,
     mixed_precision: bool,
 ) -> None:
     import os
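
The new first parametrize block pairs a block size with a large-tile size; the inline comments record how many KV-cache blocks each large tile spans, i.e. large_tile_size // block_size. A minimal standalone sketch of that relationship (blocks_per_tile is a hypothetical helper, not part of the test):

    # blocks_per_tile is a hypothetical helper mirroring the inline comments above.
    def blocks_per_tile(block_size: int, large_tile_size: int) -> int:
        assert large_tile_size % block_size == 0, "a tile must cover whole blocks"
        return large_tile_size // block_size

    for block_size, large_tile_size in [(32, 2048), (32, 4096), (32, 8192),
                                        (64, 8192)]:
        print(blocks_per_tile(block_size, large_tile_size))  # 64, 128, 256, 128
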
@@ -192,40 +203,46 @@ def test_contexted_kv_attention(
 
     from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc
 
+    assert large_tile_size % block_size == 0
+
     device = xm.xla_device()
 
-    os.environ["NEURON_CC_FLAGS"] = (
-        " --model-type=transformer -O1 "
-        " --internal-hlo2tensorizer-options='--verify-hlo' ")
+    compiler_flags = [
+        "--model-type=transformer -O1",
+        "--internal-hlo2tensorizer-options='--verify-hlo'",
+        "--retry_failed_compilation",
+    ]
+    compiler_flags_str = " ".join(compiler_flags)
+    os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
 
-    random.seed(0)
     torch.manual_seed(0)
     torch.set_printoptions(sci_mode=False)
 
-    min_ctx_len = 2
-    max_ctx_len = 64
-    min_query_len = 2
-    max_query_len = 64
-    prefill_batch_size = 2
-    decode_batch_size = 6
+    min_ctx_len = 32
+    max_ctx_len = 1024
+    min_query_len = 16
+    max_query_len = 512
+    prefill_batch_size = 4
+    decode_batch_size = 12
     batch_size = prefill_batch_size + decode_batch_size
-    block_size = 32
     max_model_len = (max_query_len + max_ctx_len) * 4
 
     max_block_per_request = max_model_len // block_size
     dtype = torch.float32
     cache_size = (batch_size * max_block_per_request) + 2
-    ctx_lens = [
-        random.randint(min_ctx_len, max_ctx_len)
-        for _ in range(prefill_batch_size)
-    ] + [
-        random.randint(min_ctx_len, max_ctx_len)
-        for _ in range(decode_batch_size)
-    ]
-    query_lens = [
-        random.randint(min_query_len, max_query_len)
-        for _ in range(prefill_batch_size)
-    ] + [1 for _ in range(decode_batch_size)]
+    prefill_ctx_lens = torch.randint(min_ctx_len,
+                                     max_ctx_len + 1, (prefill_batch_size, ),
+                                     dtype=torch.long).tolist()
+    decode_ctx_lens = torch.randint(min_ctx_len,
+                                    max_ctx_len + 1, (decode_batch_size, ),
+                                    dtype=torch.long).tolist()
+    ctx_lens = prefill_ctx_lens + decode_ctx_lens
+    query_lens = torch.randint(
+        min_query_len,
+        max_query_len + 1,
+        (prefill_batch_size, ),
+        dtype=torch.long,
+    ).tolist() + [1 for _ in range(decode_batch_size)]
     seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
     num_kv_heads = num_heads // num_queries_per_kv
 
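
Sequence and query lengths are now drawn with torch.randint under the torch.manual_seed(0) call above, which is why the random module import goes away; decode requests still get a query length of exactly 1. A minimal standalone sketch of that sampling scheme (values below are illustrative):

    import torch

    torch.manual_seed(0)  # a single seed controls all sampled lengths
    prefill_batch_size, decode_batch_size = 4, 12
    min_query_len, max_query_len = 16, 512

    # Prefill requests get a random query length; decode requests always use 1.
    query_lens = torch.randint(min_query_len, max_query_len + 1,
                               (prefill_batch_size, ),
                               dtype=torch.long).tolist()
    query_lens += [1] * decode_batch_size
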
@@ -254,7 +271,6 @@ def test_contexted_kv_attention(
     values = values[torch.randperm(cache_size)]
     block_table = values[:batch_size * max_block_per_request].view(
         batch_size, max_block_per_request)
-    torch.tensor(seq_lens, dtype=torch.long)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
     b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                             dtype=torch.long),
@@ -311,9 +327,7 @@ def test_contexted_kv_attention(
     # build neuron program
     return_debug_tensors = False
     B_P_SIZE = 128
-    LARGE_TILE_SZ = 2048
-    max_num_queries = (
-        (sum(query_lens) + block_size - 1) // block_size) * block_size
+    LARGE_TILE_SZ = large_tile_size
 
     def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
                                 num_blocks):
@@ -332,26 +346,28 @@ def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
             0,
         )
 
-    def shift_bit_length(x):
-        return 1 << (x - 1).bit_length()
+    def ceil_div(a, b):
+        return (a + b - 1) // b
+
+    def pad_to_multiple(a, b):
+        return ceil_div(a, b) * b
+
+    def pad_to_next_power_of_2(a):
+        assert a > 0
+        return 2**int(a - 1).bit_length()
 
     # calculate input shapes
-    max_num_queries_shifted = shift_bit_length(max_num_queries)
-    max_num_queries_factor = B_P_SIZE // max_num_queries_shifted
-    max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor
-    assert (max_num_queries_padded == B_P_SIZE
-            ), "invalid {max_num_queries_padded=}"
+    max_num_queries = pad_to_multiple(sum(query_lens), block_size)
+    max_num_queries = pad_to_next_power_of_2(max_num_queries)
     head_size_padded = B_P_SIZE
+    assert head_size_padded >= head_size
     context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
-    num_active_blocks_shifted = shift_bit_length(
-        ((context_lens + block_size - 1) // block_size).sum().item())
-    num_active_blocks_factor = (LARGE_TILE_SZ // block_size //
-                                num_active_blocks_shifted)
-    num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor
-    assert (num_active_blocks *
-            block_size) == LARGE_TILE_SZ, "invalid {num_active_blocks=}"
+    num_active_blocks = ceil_div(context_lens, block_size).sum().item()
+    num_active_blocks = pad_to_multiple(num_active_blocks,
+                                        LARGE_TILE_SZ // block_size)
     context_kv_len = num_active_blocks * block_size
-    assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}"
+    assert (context_kv_len %
+            LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"
 
     # pad QKV tensors
     pad_dims = (
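
With these helpers the test pads sum(query_lens) up to a multiple of block_size and then to the next power of two, and pads the number of active context blocks up to a multiple of LARGE_TILE_SZ // block_size, so context_kv_len only has to be a multiple of LARGE_TILE_SZ instead of exactly LARGE_TILE_SZ. A standalone sketch of the same arithmetic (the input numbers below are illustrative, not taken from the test run):

    def ceil_div(a, b):
        return (a + b - 1) // b

    def pad_to_multiple(a, b):
        return ceil_div(a, b) * b

    def pad_to_next_power_of_2(a):
        assert a > 0
        return 2**int(a - 1).bit_length()

    block_size, large_tile_size = 32, 8192
    total_query_tokens = 1000      # illustrative stand-in for sum(query_lens)
    total_context_blocks = 300     # illustrative number of active context blocks

    # Queries: pad to a block_size multiple, then to the next power of two.
    max_num_queries = pad_to_next_power_of_2(
        pad_to_multiple(total_query_tokens, block_size))             # 1024

    # Context blocks: pad so they fill a whole number of large tiles.
    num_active_blocks = pad_to_multiple(total_context_blocks,
                                        large_tile_size // block_size)  # 512
    context_kv_len = num_active_blocks * block_size                  # 16384 = 2 * 8192
    assert context_kv_len % large_tile_size == 0
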
@@ -360,7 +376,7 @@ def shift_bit_length(x):
         0,
         0,
         0,
-        max_num_queries_padded - query.shape[0],
+        max_num_queries - query.shape[0],
     )
     query = F.pad(query, pad_dims, "constant", 0)
     k = F.pad(k, pad_dims, "constant", 0)
@@ -397,7 +413,7 @@ def shift_bit_length(x):
             0,
             context_kv_len - prior_mask.shape[1],
             0,
-            B_P_SIZE - prior_mask.shape[0],
+            max_num_queries - prior_mask.shape[0],
         ),
         "constant",
         0,
@@ -406,9 +422,9 @@ def shift_bit_length(x):
         active_mask,
         (
             0,
-            B_P_SIZE - active_mask.shape[1],
+            max_num_queries - active_mask.shape[1],
             0,
-            B_P_SIZE - active_mask.shape[0],
+            max_num_queries - active_mask.shape[0],
         ),
         "constant",
         0,
@@ -430,6 +446,8 @@ def shift_bit_length(x):
         n_kv_head=num_kv_heads,
         head_size=head_size,
         mixed_precision=mixed_precision,
+        LARGE_TILE_SZ=LARGE_TILE_SZ,
+        return_debug_tensors=return_debug_tensors,
     )
 
     if return_debug_tensors:
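
The kernel invocation now forwards LARGE_TILE_SZ and return_debug_tensors through the keyword arguments instead of relying on a hardcoded tile size. A hypothetical, partial sketch of such a kwargs dict (only the keys visible in this diff are listed, with illustrative values; the real test passes additional arguments):

    # Hypothetical, partial kwargs for the NKI kernel call; only the keys visible
    # in this diff are listed, and the values are illustrative.
    num_kv_heads, head_size, mixed_precision = 4, 64, True
    LARGE_TILE_SZ, return_debug_tensors = 8192, False

    input_kwargs = dict(
        n_kv_head=num_kv_heads,
        head_size=head_size,
        mixed_precision=mixed_precision,
        LARGE_TILE_SZ=LARGE_TILE_SZ,                # tile size, now passed explicitly
        return_debug_tensors=return_debug_tensors,  # debug flag, now passed explicitly
    )
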
@@ -439,17 +457,15 @@ def shift_bit_length(x):
         output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
         debug_tensors = []
 
-    output_nki = torch.tensor(output_nki).cpu()
     debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]
 
     num_actual_tokens = sum(query_lens)
-    print(f"{num_actual_tokens=}")
     # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
-    output_nki = output_nki.permute(
-        0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :]
+    output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
+    output_nki = output_nki[0, :num_actual_tokens, :, :]
     output_ref_padded = F.pad(
         output_ref,
-        (0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]),
+        (0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]),
         "constant",
         0,
     )
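
The explicit torch.tensor(...) wrapping of the kernel output is dropped, so the post-processing only needs .cpu(), a permute from (bs, n_heads, seq_q, d) to (bs, seq_q, n_heads, d), and slicing away the head-size and query padding. A shape-only sketch of that step (dimensions below are illustrative):

    import torch

    # Illustrative dimensions: padded kernel output vs. the slice kept for checks.
    bs, n_heads, max_num_queries, head_size_padded = 1, 4, 1024, 128
    head_size, num_actual_tokens = 64, 1000

    output_nki = torch.zeros(bs, n_heads, max_num_queries, head_size_padded)
    output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
    output_nki = output_nki[0, :num_actual_tokens, :, :]
    assert output_nki.shape == (num_actual_tokens, n_heads, head_size)
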
