
Commit db8f04d

hotfix: accelerate plan speed of fa3 template (#690)
The fa3 template's plan step is very slow because we overestimate the workspace size that needs to be transferred from CPU to GPU; this PR fixes the issue. cc @nandor @zhyncs
1 parent bcf7a3e commit db8f04d
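
The fix follows directly from the message above: during planning, only the portion of the page-locked integer workspace that was actually written should be transferred to the GPU, not the full overestimated allocation. A minimal sketch of that idea with hypothetical names (CopyPlanBufferToDevice and its parameters are illustration only, not FlashInfer's planner API; the real change is threaded through PrefillSM90Plan and the generated kernels in the diffs below):

// Hedged sketch only: the function and parameter names are hypothetical stand-ins.
#include <cuda_runtime.h>
#include <cstddef>

cudaError_t CopyPlanBufferToDevice(void* device_int_workspace,
                                   const void* page_locked_int_workspace,
                                   size_t bytes_actually_used,  // used prefix, not capacity
                                   cudaStream_t stream) {
  // Copy only what the planner wrote; transferring the full overestimated
  // workspace on every plan call is what made the fa3 plan step slow.
  return cudaMemcpyAsync(device_int_workspace, page_locked_int_workspace, bytes_actually_used,
                         cudaMemcpyHostToDevice, stream);
}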

13 files changed: +323 -145 lines

aot_build_utils/generate_batch_paged_prefill_sm90_inst.py (+10 -2)

@@ -39,11 +39,19 @@ def get_cu_file_str(
     def get_insts(attention_variant):
         return "\n".join(
             [
-                """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
+                """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
     Params& params,
     cudaStream_t stream);
 
-template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
+    Params& params,
+    cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
+    Params& params,
+    cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
     Params& params,
     cudaStream_t stream);
 """.format(

aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py (+10 -2)

@@ -40,11 +40,19 @@ def get_cu_file_str(
     def get_insts(attention_variant):
         return "\n".join(
             [
-                """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
+                """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
     Params& params,
     cudaStream_t stream);
 
-template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
+template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
+    Params& params,
+    cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
+    Params& params,
+    cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
     Params& params,
     cudaStream_t stream);
 """.format(

csrc/aot_extension_utils.h (-11)

@@ -30,17 +30,6 @@
 #define DISPATCH_mask_mode(expr, const_expr, ...) \
   _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__))
 
-#define DISPATCH_BOOL(expr, const_expr, ...) \
-  [&]() -> bool { \
-    if (expr) { \
-      constexpr bool const_expr = true; \
-      return __VA_ARGS__(); \
-    } else { \
-      constexpr bool const_expr = false; \
-      return __VA_ARGS__(); \
-    } \
-  }()
-
 #define DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_dtype, kv_dtype, c_type_q, c_type_kv, ...) \
   [&]() -> bool { \
     if (kv_dtype == q_dtype) { \

csrc/batch_prefill_sm90.cu (+39 -31)

@@ -29,14 +29,14 @@
 namespace flashinfer {
 
 template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
-          typename AttentionVariant, typename DTypeQ, typename DTypeKV, typename DTypeO,
-          typename IdType>
+          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
+          typename DTypeKV, typename DTypeO, typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
     BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream);
 
 template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
-          typename AttentionVariant, typename DTypeQ, typename DTypeKV, typename DTypeO,
-          typename IdType>
+          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
+          typename DTypeKV, typename DTypeO, typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheDispatched(
     BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream);
 
@@ -47,9 +47,9 @@ using namespace flashinfer;
 std::vector<int64_t> BatchPrefillWithKVCacheSM90Plan(
     unsigned int head_dim, bool causal, at::Tensor float_workspace_buffer,
     at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer,
-    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
-    bool enable_cuda_graph, int64_t cuda_stream) {
+    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int total_num_rows,
+    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+    unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
   size_t int_workspace_size_in_bytes =
@@ -61,12 +61,13 @@ std::vector<int64_t> BatchPrefillWithKVCacheSM90Plan(
 
   cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
 
-  cudaError_t status = PrefillSM90Plan(
-      float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
-      int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
-      int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
-      kv_indptr.data_ptr<IdType>(), kv_len_arr.data_ptr<IdType>(), batch_size, num_qo_heads,
-      num_kv_heads, head_dim, page_size, causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
+  cudaError_t status =
+      PrefillSM90Plan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
+                      int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
+                      int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
+                      kv_indptr.data_ptr<IdType>(), kv_len_arr.data_ptr<IdType>(), total_num_rows,
+                      batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, causal,
+                      enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
 
   TORCH_CHECK(status == cudaSuccess,
               "PrefillSM90Plan failed with error: ", cudaGetErrorString(status));
@@ -151,19 +152,23 @@ void BatchPrefillWithRaggedKVCacheSM90Run(
       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
   params.work_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
 
+  bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
+
   return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
     return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {
       return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
         return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
-          using AttentionVariant =
-              std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
-          cudaError_t status =
-              BatchPrefillWithRaggedKVCacheDispatched<HEAD_DIM, MASK_MODE, USE_SWA,
-                                                      AttentionVariant>(params, stream);
-          TORCH_CHECK(status == cudaSuccess,
-                      "BatchPrefillWithRaggedKVCacheSM90Run failed with error: ",
-                      cudaGetErrorString(status));
-          return true;
+          return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
+            using AttentionVariant =
+                std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
+            cudaError_t status = BatchPrefillWithRaggedKVCacheDispatched<
+                HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
+                params, stream);
+            TORCH_CHECK(status == cudaSuccess,
+                        "BatchPrefillWithRaggedKVCacheSM90Run failed with error: ",
+                        cudaGetErrorString(status));
+            return true;
+          });
         });
       });
     });
@@ -259,20 +264,23 @@ void BatchPrefillWithPagedKVCacheSM90Run(
       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
   params.work_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
   params.kv_indices = static_cast<IdType*>(paged_kv_indices.data_ptr());
+  bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
 
   return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
     return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {
      return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
        return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
-          using AttentionVariant =
-              std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
-          cudaError_t status =
-              BatchPrefillWithPagedKVCacheDispatched<HEAD_DIM, MASK_MODE, USE_SWA,
-                                                     AttentionVariant>(params, stream);
-          TORCH_CHECK(status == cudaSuccess,
-                      "BatchPrefillWithPagedKVCacheSM90Run failed with error: ",
-                      cudaGetErrorString(status));
-          return true;
+          return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
+            using AttentionVariant =
+                std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
+            cudaError_t status = BatchPrefillWithPagedKVCacheDispatched<
+                HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
+                params, stream);
+            TORCH_CHECK(status == cudaSuccess,
+                        "BatchPrefillWithPagedKVCacheSM90Run failed with error: ",
+                        cudaGetErrorString(status));
+            return true;
+          });
         });
       });
     });

csrc/flashinfer_ops_sm90.cu (+3 -3)

@@ -33,9 +33,9 @@ void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q
 std::vector<int64_t> BatchPrefillWithKVCacheSM90Plan(
     unsigned int head_dim, bool causal, at::Tensor float_workspace_buffer,
     at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer,
-    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
-    bool enable_cuda_graph, int64_t cuda_stream);
+    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int total_num_rows,
+    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+    unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream);
 
 void BatchPrefillWithRaggedKVCacheSM90Run(
     unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,

csrc/pytorch_extension_utils.h (+11)

@@ -189,6 +189,17 @@
     return __VA_ARGS__(); \
   }
 
+#define DISPATCH_BOOL(expr, const_expr, ...) \
+  [&]() -> bool { \
+    if (expr) { \
+      constexpr bool const_expr = true; \
+      return __VA_ARGS__(); \
+    } else { \
+      constexpr bool const_expr = false; \
+      return __VA_ARGS__(); \
+    } \
+  }()
+
 inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name,
                         const char* b_name) {
   TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ",
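
For reference, a tiny usage sketch of the macro added above; the function and flag names are hypothetical and only illustrate the mechanics: the runtime boolean selects one of two branches, each of which sees the named constant as a genuine constexpr, and the lambda's bool result is forwarded to the caller.

// Assumes the DISPATCH_BOOL macro defined directly above; enable_fast_path and
// ENABLE_FAST_PATH are illustration-only names, not identifiers from the codebase.
bool RunWithOptionalFastPath(bool enable_fast_path) {
  return DISPATCH_BOOL(enable_fast_path, ENABLE_FAST_PATH, [&] {
    if constexpr (ENABLE_FAST_PATH) {
      // Compiled only in the ENABLE_FAST_PATH == true instantiation; the
      // constant can also be used as a template argument here.
    }
    return true;  // forwarded through DISPATCH_BOOL to the caller
  });
}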
