Commit 6e6f38d

bugfix: various AOT issues (#752)

1. Add missing instantiations for batch prefill and single prefill (the mechanism is sketched below).
2. Skip FP8 in the SM90 prefill dispatch.
3. Fix the incorrect prefill pybind declaration.
4. Fix the mismatched URI for batch prefill.
5. Add a DISPATCH_head_dim_sm90 macro, since SM90 only supports head dims 64, 128, and 256.
6. Remove `csrc/aot_default_additional_params.h` and add it to .gitignore.
1 parent e840db1

13 files changed: +118 / -127 lines
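
Background for fix 1: in the AOT build the attention templates are compiled ahead of time, so every template-parameter combination the runtime dispatcher can select must appear as an explicit instantiation in some generated .cu file; a missing combination only surfaces later, as an undefined symbol when the extension loads. A minimal sketch of the mechanism, with hypothetical names throughout (not FlashInfer's actual signatures):

#include <cuda_runtime.h>

// decode_impl.cuh -- the kernel template definition (illustrative only).
struct MyParams {};

template <int HEAD_DIM, bool USE_LOGITS_SOFT_CAP, typename Params>
cudaError_t RunDecode(Params params, cudaStream_t stream) {
  // ... launch a kernel specialized on HEAD_DIM and USE_LOGITS_SOFT_CAP ...
  return cudaSuccess;
}

// generated_decode_64.cu -- includes decode_impl.cuh and explicitly
// instantiates each dispatchable combination; omitting one leaves an
// undefined symbol for callers that dispatch to it.
template cudaError_t RunDecode<64, false, MyParams>(MyParams, cudaStream_t);
template cudaError_t RunDecode<64, true, MyParams>(MyParams, cudaStream_t);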

.gitignore (+1)

@@ -12,6 +12,7 @@ flashinfer/_build_meta.py
 flashinfer/data/
 flashinfer/jit/aot_config.py
 src/generated/
+csrc/aot_default_additional_params.h

 # DS_Store files
 .DS_store

aot_build_utils/generate.py (-14)

@@ -24,7 +24,6 @@
     generate_batch_paged_decode_inst,
     generate_batch_paged_prefill_inst,
     generate_batch_ragged_prefill_inst,
-    generate_dispatch_inc,
     generate_single_decode_inst,
     generate_single_prefill_inst,
 )
@@ -48,19 +47,6 @@ def write_if_different(path: Path, content: str) -> None:
 
     path.mkdir(parents=True, exist_ok=True)
 
-    # dispatch.inc
-    write_if_different(
-        path / "dispatch.inc",
-        generate_dispatch_inc.get_dispatch_inc_str(
-            argparse.Namespace(
-                head_dims=head_dims,
-                pos_encoding_modes=pos_encoding_modes,
-                use_fp16_qk_reductions=use_fp16_qk_reductions,
-                mask_modes=mask_modes,
-            )
-        ),
-    )
-
     write_if_different(
         path / "aot_default_additional_params.h",
         generate_aot_default_additional_params_header.get_aot_default_additional_params_header_str(),

aot_build_utils/generate_batch_paged_decode_inst.py (+12)

@@ -35,12 +35,24 @@ def get_cu_file_str(
 
 using Params = BatchDecodeParams<{dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>;
 
+template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
+    /*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>(
+    Params params,
+    {dtype_out}* tmp_v, float* tmp_s,
+    cudaStream_t stream);
+
 template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
     /*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>(
     Params params,
     {dtype_out}* tmp_v, float* tmp_s,
     cudaStream_t stream);
 
+template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
+    /*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>(
+    Params params,
+    {dtype_out}* tmp_v, float* tmp_s,
+    cudaStream_t stream);
+
 template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
     /*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>(
     Params params,
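
To make the generated output concrete: substituting, for illustration, head_dim=128, fp16 for Q/KV/output, int32 indices, and the no-RoPE encoding mode (assuming the mode renders as PosEncodingMode::kNone and fp16 renders as half), the first added instantiation comes out roughly as:

using Params = BatchDecodeParams<half, half, half, int32_t>;

template cudaError_t BatchDecodeWithPagedKVCacheDispatched<128, PosEncodingMode::kNone, DefaultAttention<
    /*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>(
    Params params,
    half* tmp_v, float* tmp_s,
    cudaStream_t stream);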

aot_build_utils/generate_dispatch_inc.py (+12)

@@ -31,6 +31,17 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
     dispatch_head_dims_str = f"""#define _DISPATCH_CASES_head_dim(case_var, ...) \\
 {dispatch_head_dims_entries}
 // EOL
+"""
+    # head dims for sm90
+    dispatch_head_dims_sm90_entries = "\n".join(
+        [
+            "  _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(_)
+            for _ in args.head_dims_sm90
+        ]
+    )
+    dispatch_head_dims_sm90_str = f"""#define _DISPATCH_CASES_head_dim_sm90(case_var, ...) \\
+{dispatch_head_dims_sm90_entries}
+// EOL
 """
     # positional encoding modes
     dispatch_pos_encoding_modes_entries = "\n".join(
@@ -73,6 +84,7 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
     return "\n".join(
         [
             dispatch_head_dims_str,
+            dispatch_head_dims_sm90_str,
             dispatch_pos_encoding_modes_str,
             dispatch_use_fp16_qk_reductions_str,
             dispatch_mask_mode_str,
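
Assuming the generator is invoked with head_dims_sm90 = [64, 128, 256] (per item 5 of the commit message), the new block emitted into dispatch.inc reads:

#define _DISPATCH_CASES_head_dim_sm90(case_var, ...) \
  _DISPATCH_CASE(64, case_var, __VA_ARGS__)          \
  _DISPATCH_CASE(128, case_var, __VA_ARGS__)         \
  _DISPATCH_CASE(256, case_var, __VA_ARGS__)         \
// EOL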

aot_build_utils/generate_single_decode_inst.py (+13)

@@ -34,6 +34,18 @@ def get_cu_file_str(
 
 using Params = SingleDecodeParams<{dtype_q}, {dtype_kv}, {dtype_out}>;
 
+template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
+    /*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>(
+    Params params,
+    {dtype_out}* tmp,
+    cudaStream_t stream);
+
+template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
+    /*use_custom_mask=*/false, /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi_bias=*/false>, Params>(
+    Params params,
+    {dtype_out}* tmp,
+    cudaStream_t stream);
+
 template cudaError_t SingleDecodeWithKVCacheDispatched<{head_dim}, {pos_encoding_mode}, DefaultAttention<
     /*use_custom_mask=*/false, /*use_sliding_window=*/true, /*use_logits_soft_cap=*/false, /*use_alibi_bias=*/false>, Params>(
     Params params,
@@ -45,6 +57,7 @@ def get_cu_file_str(
     Params params,
     {dtype_out}* tmp,
     cudaStream_t stream);
+
 }}
 """.format(
     head_dim=head_dim,

csrc/aot_default_additional_params.h (-66)

This file was deleted.

csrc/aot_extension_utils.h (+3)

@@ -19,6 +19,9 @@
 #define DISPATCH_head_dim(expr, const_expr, ...) \
   _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__))
 
+#define DISPATCH_head_dim_sm90(expr, const_expr, ...) \
+  _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim_sm90(const_expr, __VA_ARGS__))
+
 #define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \
   _DISPATCH_SWITCH("positional encoding mode", expr, \
                    _DISPATCH_CASES_pos_encoding_mode(const_expr, __VA_ARGS__))
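
Combined with the generated case list, a call like DISPATCH_head_dim_sm90(head_dim, HEAD_DIM, [&] { ... }) expands, in spirit, to the switch below (a sketch: the real _DISPATCH_SWITCH / _DISPATCH_CASE helpers also raise a TORCH_CHECK error naming the unsupported value):

// body is the lambda passed as the variadic argument.
switch (head_dim) {
  case 64: {
    constexpr auto HEAD_DIM = 64;
    return body();  // instantiated with HEAD_DIM == 64
  }
  case 128: {
    constexpr auto HEAD_DIM = 128;
    return body();
  }
  case 256: {
    constexpr auto HEAD_DIM = 256;
    return body();
  }
  default:
    return false;  // any other head dim is rejected on SM90
}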

csrc/batch_prefill_sm90_config.inc (+17/-15)

@@ -31,22 +31,24 @@ using IdType = int32_t;
       USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, ...)       \
   {                                                                                \
     DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {                                 \
-      return DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(                                  \
-          q_scalar_type, kv_scalar_type, dtype_q, dtype_kv, [&] {                  \
-            using DTypeQ = cutlass_dtype_t<dtype_q>;                               \
-            using DTypeKV = cutlass_dtype_t<dtype_kv>;                             \
-            using DTypeO = DTypeQ;                                                 \
-            using RaggedParams = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
-            using PagedParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
-            return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {                     \
-              return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] {     \
-                return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
-                  using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>;  \
-                  __VA_ARGS__();                                                   \
-                  return true;                                                     \
-                });                                                                \
-              });                                                                  \
+      if (q_scalar_type != kv_scalar_type) {                                       \
+        return false;                                                              \
+      }                                                                            \
+      return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, dtype_q, [&] {    \
+        using DTypeQ = cutlass_dtype_t<dtype_q>;                                   \
+        using DTypeKV = DTypeQ;                                                    \
+        using DTypeO = DTypeQ;                                                     \
+        using RaggedParams = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
+        using PagedParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
+        return DISPATCH_head_dim_sm90(head_dim, HEAD_DIM, [&] {                    \
+          return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] {         \
+            return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
+              using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>;      \
+              __VA_ARGS__();                                                       \
+              return true;                                                         \
             });                                                                    \
           });                                                                      \
+        });                                                                        \
+      });                                                                          \
     });                                                                            \
   }

csrc/flashinfer_ops.cu (+3/-3)

@@ -105,9 +105,9 @@ void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::
 std::vector<int64_t> BatchPrefillWithKVCachePlan(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
-    unsigned total_num_rows, unsigned int batch_size, unsigned int num_qo_heads,
-    unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph,
-    unsigned int head_dim, bool causal, int64_t cuda_stream);
+    at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size,
+    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
+    bool enable_cuda_graph, unsigned int head_dim, bool causal, int64_t cuda_stream);
 
 void BatchPrefillWithRaggedKVCacheRun(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,

csrc/flashinfer_ops_sm90.cu (+7/-7)

@@ -23,16 +23,16 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo
                             int64_t cuda_stream);
 
 void single_prefill_with_kv_cache_sm90(
-    at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, std::optional<at::Tensor> maybe_lse,
-    unsigned int mask_mode_code, unsigned int layout,
+    at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, at::Tensor o,
+    std::optional<at::Tensor> maybe_lse, unsigned int mask_mode_code, unsigned int layout,
     int32_t window_left SINGLE_PREFILL_SM90_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);
 
 std::vector<int64_t> BatchPrefillWithKVCacheSM90Plan(
-    unsigned int head_dim, bool causal, at::Tensor float_workspace_buffer,
-    at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer,
-    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int total_num_rows,
-    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-    unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream);
+    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
+    at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
+    at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size,
+    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
+    bool enable_cuda_graph, unsigned int head_dim, bool causal, int64_t cuda_stream);
 
 void BatchPrefillWithRaggedKVCacheSM90Run(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,

csrc/single_prefill_sm90_config.inc (+16/-14)

@@ -31,21 +31,23 @@ using IdType = int32_t;
       USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...)                          \
   {                                                                                \
     DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {                                 \
-      return DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(                                  \
-          q_scalar_type, kv_scalar_type, dtype_q, dtype_kv, [&] {                  \
-            using DTypeQ = cutlass_dtype_t<dtype_q>;                               \
-            using DTypeKV = cutlass_dtype_t<dtype_kv>;                             \
-            using DTypeO = DTypeQ;                                                 \
-            using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;           \
-            return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {                     \
-              return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] {     \
-                return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
-                  using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>;  \
-                  __VA_ARGS__();                                                   \
-                  return true;                                                     \
-                });                                                                \
-              });                                                                  \
+      if (q_scalar_type != kv_scalar_type) {                                       \
+        return false;                                                              \
+      }                                                                            \
+      return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, dtype_q, [&] {    \
+        using DTypeQ = cutlass_dtype_t<dtype_q>;                                   \
+        using DTypeKV = DTypeQ;                                                    \
+        using DTypeO = DTypeQ;                                                     \
+        using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;               \
+        return DISPATCH_head_dim_sm90(head_dim, HEAD_DIM, [&] {                    \
+          return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] {         \
+            return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
+              using AttentionVariant = DefaultAttention<USE_LOGITS_SOFT_CAP>;      \
+              __VA_ARGS__();                                                       \
+              return true;                                                         \
             });                                                                    \
           });                                                                      \
+        });                                                                        \
+      });                                                                          \
     });                                                                            \
   }
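
Both SM90 config changes (batch prefill above and single prefill here) implement fix 2 the same way: return false when the Q and KV scalar types differ, then dispatch through an FP16/BF16-only macro so FP8 inputs never reach the SM90 prefill path. A self-contained sketch of the pattern, with hypothetical names (the real code goes through at::ScalarType, DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16, and cutlass dtype traits):

#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Stand-in for at::ScalarType in this sketch.
enum class ScalarType { Half, BFloat16, Float8_e4m3fn, Float8_e5m2 };

// Dispatch only half-precision dtypes; anything else (notably FP8) is
// skipped by returning false, i.e. "could not dispatch".
template <typename Func>
bool DispatchFP16Dtype(ScalarType t, Func&& f) {
  switch (t) {
    case ScalarType::Half:     return f(half{});
    case ScalarType::BFloat16: return f(__nv_bfloat16{});
    default:                   return false;
  }
}

bool DispatchPrefillSM90(ScalarType q_type, ScalarType kv_type) {
  if (q_type != kv_type) return false;  // mixed Q/KV dtypes: skip
  return DispatchFP16Dtype(q_type, [&](auto tag) {
    using DTypeQ = decltype(tag);  // the kernel would be instantiated here
    return true;
  });
}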

flashinfer/jit/attention.py (+1/-1)

@@ -249,7 +249,7 @@ def get_batch_prefill_uri(
     use_fp16_qk_reduction: bool,
 ) -> str:
     return (
-        f"batch_prefill_{backend}_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
+        f"batch_prefill_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
         f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
         f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
         f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
