[Kernel][ROCM] Upstream prefix prefill speed up for vLLM V1 #13305

Merged · 73 commits · Apr 23, 2025

Changes from 55 commits

Commits (73)
b6b00d7
init
SageMoore Feb 5, 2025
fa52268
temporarily remove torch from requirements-build
SageMoore Feb 5, 2025
f563276
move rocm logic to its own attention backend
SageMoore Feb 6, 2025
2a03b92
actually add backend
SageMoore Feb 6, 2025
4bdf7de
more rocm refactoring
SageMoore Feb 7, 2025
875fcfc
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Feb 7, 2025
e507e30
more rocm refactoring
SageMoore Feb 7, 2025
b9ce259
hack to fix the multiprocessing isssue
SageMoore Feb 7, 2025
f2cc5e3
minor print fix
SageMoore Feb 7, 2025
d6f6c5c
remove cruft
SageMoore Feb 7, 2025
2bf214a
format
SageMoore Feb 7, 2025
11411cb
modify requirements files
SageMoore Feb 7, 2025
c2499bf
remove basic.py changes
SageMoore Feb 7, 2025
cf6f691
cleanup
SageMoore Feb 7, 2025
4505f53
add support for passing in softmax scales to the context_attn_fwd
SageMoore Feb 7, 2025
9a0416a
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Feb 7, 2025
ef9ae86
added requirements-rocm-build
SageMoore Feb 10, 2025
0ccef65
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Feb 10, 2025
a00a2d9
minor setup.py fix
SageMoore Feb 10, 2025
afb15f5
add batch size back in
SageMoore Feb 10, 2025
08a25b7
revert setup.py change
SageMoore Feb 10, 2025
55eb036
update setup.py
SageMoore Feb 10, 2025
95df571
init
SageMoore Feb 10, 2025
0bfe435
init
SageMoore Feb 11, 2025
4b62de2
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Feb 11, 2025
d2f3c85
minor fix
SageMoore Feb 11, 2025
442bc7b
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Feb 12, 2025
9472636
minor fix
SageMoore Feb 12, 2025
c7497f3
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Feb 12, 2025
21d8d6a
update error messages
SageMoore Feb 12, 2025
a1cac3d
init
SageMoore Feb 13, 2025
40a64d7
Merge branch 'neuralmagic_sage_amd-v1' into upstream_prefix_prefill_s…
Feb 13, 2025
329ad79
Merge branch 'neuralmagic_sage_prefix-prefill-refactor' into upstream…
Feb 13, 2025
ad1db61
Merge branch 'neuralmagic_sage_rocm-fp4-fix' into upstream_prefix_pre…
Feb 13, 2025
c02b1e6
new prefix_prefill
Feb 14, 2025
540b286
dwordx4 for k and v from cache
Feb 20, 2025
4784522
merge with main
Feb 20, 2025
1d9eb50
follow up merge with main
Feb 20, 2025
9eb5566
different stages for different loops
Feb 22, 2025
0afb796
merge with main
Feb 24, 2025
9bc9217
unroll factors tunning
Feb 25, 2025
2b84448
linter
Feb 25, 2025
1067508
default prefix_prefill for triton lower than 3.2, NV case
Feb 25, 2025
fb21239
Merge branch 'upstream/main' into upstream_prefix_prefill_speed_up
qli88 Feb 28, 2025
3b99cf7
Merge branch 'upstream/main' into upstream_prefix_prefill_speed_up
Feb 28, 2025
506b0c4
original softmax restored to get back accuracy
Mar 4, 2025
1dc5142
merge with main
Mar 4, 2025
e044108
Merge branch 'upstream/main' into upstream_prefix_prefill_speed_up
Mar 10, 2025
05c3d3b
adaptation to ibm kernel
Mar 10, 2025
e76f27f
softmax computation correction
Mar 11, 2025
da80a03
a comment for triton version
Mar 13, 2025
83a86a8
Merge branch 'upstream_prefix_prefill_speed_up' of github.com:ROCm/vl…
Mar 19, 2025
1369809
Merge branch 'upstream/main' into upstream_prefix_prefill_speed_up
Mar 19, 2025
81277c8
kpack is not supported on NVidia triton
Mar 19, 2025
a4000df
kpack is not supported on NVidia triton
Mar 19, 2025
a027e5c
reduced space of autotuning
Mar 27, 2025
db608bb
Merge branch 'upstream/main' into upstream_prefix_prefill_speed_up
Apr 6, 2025
81c2739
giving up on autotune and selecting one config
Apr 8, 2025
7add0e2
Merge branch 'upstream/main' into upstream_prefix_prefill_speed_up
Apr 8, 2025
5a17950
fixing test with only to ROCM waves per eu and max_seq_len None
Apr 8, 2025
5d9a929
renaming kernel
Apr 9, 2025
27f044b
clean up and fix for failed kernel tests
Apr 10, 2025
cfd60c9
clean up and fix for failed kernel tests
Apr 10, 2025
0a26697
clean up and fix for failed kernel tests
Apr 10, 2025
35a6e49
got rid of autotuner and get stable runs right from the first iteration
Apr 11, 2025
6d5b3f2
restoring paged attn as there is no autotuning anymore and that will …
Apr 12, 2025
7140d1a
poking test rerun as one failed and seems not because of this change
Apr 13, 2025
169f714
Merge branch 'main' of github.com:vllm-project/vllm into upstream_pre…
Apr 14, 2025
f437b11
Merge branch 'upstream/main' into upstream_prefix_prefill_speed_up
Apr 14, 2025
ba078b6
comment correction
Apr 14, 2025
617ef08
dot operation in triton doesn't support k to be 8 so increasing block…
Apr 15, 2025
771ad9e
to kick CIs again Async Engine, Inputs, Utils, Worker Test seems flaky
Apr 15, 2025
b6bf365
to kick CIs again
Apr 15, 2025
292 changes: 287 additions & 5 deletions vllm/attention/ops/prefix_prefill.py
@@ -269,6 +269,228 @@ def _fwd_kernel(
(offs_m[:, None] < cur_batch_query_len))
return

# On Triton versions lower than 3.2, the assertion
# Assertion `!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) &&
# "mma -> mma layout conversion is only supported on Ampere"' failed.
# is observed.
if triton.__version__ >= "3.2.0":
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, \
"num_unroll_cache": num_unroll_cache, \
"num_unroll_request": num_unroll_request } | \
({"kpack": 2} \
if current_platform.is_rocm() else {}), \
num_warps=num_warps) \
for block_m in [32, 64, 128] for block_n in [32, 64, 128] \
for num_warps in [4, 8] for num_unroll_cache in [1, 2] \
for num_unroll_request in [1, 2]
],
key=["BLOCK_M", "BLOCK_N", "BLOCK_SIZE", \
"BLOCK_DMODEL_PADDED", "BLOCK_DMODEL"]
)
@triton.jit
def _fwd_kernel(
Q, K, V, K_cache, V_cache, B_Loc, sm_scale, k_scale, v_scale,
B_Start_Loc, B_Seqlen, x: tl.constexpr, Out, stride_b_loc_b,
stride_b_loc_s, stride_qbs, stride_qh, stride_qd, stride_kbs,
stride_kh, stride_kd, stride_vbs, stride_vh, stride_vd,
stride_obs, stride_oh, stride_od, stride_k_cache_bs,
stride_k_cache_h, stride_k_cache_d,
stride_k_cache_bl: tl.constexpr, stride_k_cache_x,
stride_v_cache_bs, stride_v_cache_h, stride_v_cache_d,
stride_v_cache_bl, num_queries_per_kv: tl.constexpr,
IN_PRECISION: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL_PADDED: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr,
SLIDING_WINDOW: tl.constexpr, num_unroll_cache: tl.constexpr,
num_unroll_request: tl.constexpr, SKIP_DECODE: tl.constexpr):

cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)

cur_kv_head = cur_head // num_queries_per_kv

cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
cur_batch_query_len = (cur_batch_in_all_stop_index -
cur_batch_in_all_start_index)
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

if SKIP_DECODE and cur_batch_query_len == 1:
return

# start position inside of the query
# generally, N goes over kv, while M goes over query_len
block_start_loc = BLOCK_M * start_m

# initialize offsets
# [BLOCK_SIZE]; starts at 0
offs_bs_n = tl.arange(0, BLOCK_SIZE)
# [N]; starts at 0
offs_n = tl.arange(0, BLOCK_N)
# [D]; starts at 0
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
# [M]; starts at current position in query
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# [M,D]
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)

dim_mask = tl.where(
tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
0).to(tl.int1) # [D]

q = tl.load(Q + off_q,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_query_len),
other=0.0) # [M,D]

# initialize running max (m_i) and softmax denominator (l_i)
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
dtype=tl.float32) # [M,D]

# compute query against context (no causal mask here)
for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \
loop_unroll_factor=num_unroll_cache):
start_n = tl.multiple_of(start_n, BLOCK_SIZE)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
(start_n // BLOCK_SIZE) * stride_b_loc_s)
# [D,BLOCK_SIZE]
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)

# [BLOCK_SIZE,D]
off_v = (bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
offs_bs_n[:, None] * stride_v_cache_bl)
k_load = tl.load(K_cache + off_k)

if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load

qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where((start_n + offs_bs_n[None, :])
< cur_batch_ctx_len, qk, float("-inf"))
qk *= sm_scale
if SLIDING_WINDOW > 0:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_bs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
#
# We can't use -inf here, because the
# sliding window may lead to the entire row being masked.
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
(start_n + offs_bs_n[None, :])
< SLIDING_WINDOW, qk, -10000)

# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
acc = acc * alpha[:, None]

# update acc
v_load = tl.load(V_cache + off_v)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)

acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# update m_i and l_i
l_i = l_i * alpha + l_ij
m_i = m_ij

off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v

# block_mask is 0 when we're already past the current query length
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

# compute query against itself (with causal mask)
for start_n in tl.range(0, \
block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \
loop_unroll_factor=num_unroll_request):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_query_len),
other=0.0)

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk *= sm_scale
# apply causal mask
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]),
qk, float("-inf"))
if SLIDING_WINDOW > 0:
qk = tl.where(
offs_m[:, None] - (start_n + offs_n[None, :])
< SLIDING_WINDOW, qk, -10000)

# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
acc = acc * alpha[:, None]

# update acc
v = tl.load(
v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_query_len),
other=0.0)
p = p.to(v.dtype)

acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# update m_i and l_i
l_i = l_i * alpha + l_ij
m_i = m_ij

acc = acc / l_i[:, None]

# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] &
(offs_m[:, None] < cur_batch_query_len))
return
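
Both loops above use the same flash-attention-style online softmax: keep a running row maximum m_i and a running denominator l_i, rescale the accumulator by alpha = exp(m_i - m_ij) whenever the maximum grows, and normalize once at the end. A minimal NumPy sketch of that bookkeeping (illustrative only — names, scaling order, and masking are simplified and not part of this diff):

import numpy as np

def online_softmax_attention(q, k_blocks, v_blocks, sm_scale):
    # q: [M, D]; k_blocks: list of [D, N]; v_blocks: list of [N, D]
    M, D = q.shape
    m_i = np.full(M, -np.inf)   # running row maximum
    l_i = np.ones(M)            # running denominator; wiped by the first alpha (= 0), as in the kernel
    acc = np.zeros((M, D))      # unnormalized output accumulator
    for k, v in zip(k_blocks, v_blocks):
        qk = (q @ k) * sm_scale                  # [M, N]
        m_ij = np.maximum(m_i, qk.max(axis=1))   # new running maximum
        p = np.exp(qk - m_ij[:, None])
        alpha = np.exp(m_i - m_ij)               # rescale factor for the old state
        acc = acc * alpha[:, None] + p @ v
        l_i = l_i * alpha + p.sum(axis=1)
        m_i = m_ij
    return acc / l_i[:, None]                    # single normalization at the end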

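The context loop gathers K from the paged cache, whose layout is [num_blocks, num_kv_heads, head_size // x, block_size, x]: B_Loc maps each logical block of the sequence to a physical cache block, and head-dim element d is stored at index (d // x, ..., d % x). A rough PyTorch equivalent of that gather for a single kv head (hypothetical helper, for illustration only):

import torch

def gather_k_for_head(k_cache, block_table_row, kv_head, ctx_len):
    # k_cache: [num_blocks, num_kv_heads, head_size // x, block_size, x]
    # block_table_row: physical block ids for one sequence (one row of B_Loc)
    num_blocks, num_kv_heads, hd_over_x, block_size, x = k_cache.shape
    head_size = hd_over_x * x
    pos = torch.arange(ctx_len)
    blk = block_table_row[pos // block_size]   # physical block per token
    slot = pos % block_size                    # slot inside the block
    d = torch.arange(head_size)
    # k[pos, d] = k_cache[blk, kv_head, d // x, slot, d % x]
    return k_cache[blk[:, None], kv_head,
                   (d // x)[None, :], slot[:, None], (d % x)[None, :]]  # [ctx_len, head_size]
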
@triton.jit
def _fwd_kernel_flash_attn_v2(
Q,
@@ -734,10 +956,6 @@ def context_attention_fwd(q,
skip_decode=False):

q_dtype_is_f32 = q.dtype is torch.float32
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# if q.dtype is torch.float32:
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK

# Turing does have tensor core for float32 multiplication
# use ieee as fallback for triton kernels work. There is also
Expand Down Expand Up @@ -778,13 +996,18 @@ def context_attention_fwd(q,
num_queries_per_kv = q.shape[1] // k.shape[1]

assert batch + 1 == len(b_start_loc)
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,

# 0 means "disable"
if sliding_window is None or sliding_window <= 0:
sliding_window = 0

if alibi_slopes is not None:
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# if q.dtype is torch.float32:
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
# batch, head,
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
_fwd_kernel_alibi[grid](
q,
k,
Expand Down Expand Up @@ -839,6 +1062,65 @@ def context_attention_fwd(q,
)
return

if triton.__version__ >= "3.2.0":
grid = lambda META: (batch, head,
triton.cdiv(max_input_len, META["BLOCK_M"]))
_fwd_kernel[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
k_scale,
v_scale,
b_start_loc,
b_seq_len,
k_cache.shape[4],
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4
), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
BLOCK_SIZE=v_cache.shape[3],
num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode,
)
return
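
This launch path uses a callable grid so the third launch dimension follows whichever BLOCK_M the autotuner selects; Triton passes the winning config's meta-parameters to the grid function. A self-contained sketch of that pattern with a hypothetical kernel (not from this file):

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({"BLOCK_M": bm}, num_warps=w)
             for bm in (32, 64, 128) for w in (4, 8)],
    key=["N"],
)
@triton.jit
def _double_kernel(x_ptr, out_ptr, N, BLOCK_M: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * 2.0, mask=mask)

def double(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    # resolved per launch with the chosen config's BLOCK_M
    grid = lambda META: (triton.cdiv(x.numel(), META["BLOCK_M"]), )
    _double_kernel[grid](x, out, x.numel())
    return out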

# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# if q.dtype is torch.float32:
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
# batch, head,
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
_fwd_kernel[grid](
q,
k,