
Commit dbb1e4e

abmfy, youkaichao, and yzh119 authored
refactor: change to TORCH_LIBRARY (#823)
This PR updates FlashInfer's C++/CUDA extensions from pybind11 modules to `torch.library` registrations (the `TORCH_LIBRARY` macros), the mechanism recommended since PyTorch 2.5. The change was mainly implemented in #764. We have verified that the issue in #820 was not caused by this PR, so we are reopening it.

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: abmfy <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>
1 parent c716aed commit dbb1e4e
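
The registration change itself is mechanical: every `*_pybind.cu` translation unit that used to expose its entry points through `PYBIND11_MODULE` now registers them with the PyTorch dispatcher through `TORCH_LIBRARY_FRAGMENT`, as the diffs below show. The following is a minimal sketch of the two styles; the op name `scale` and the namespace `my_ext` are placeholders, not FlashInfer symbols (the real code uses the `TORCH_EXTENSION_NAME` macro as the namespace and the plan/run functions as entry points).

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Plain C++ function; the dispatcher infers its schema from the signature.
at::Tensor scale(at::Tensor x, double factor) { return x * factor; }

// Old style: a pybind11 extension module exposing an ordinary Python function.
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("scale", &scale, "scale a tensor");
//   }

// New style: register the function as a dispatcher op inside a library fragment.
TORCH_LIBRARY_FRAGMENT(my_ext, m) {
  m.def("scale", scale);
}
```

Once the shared library is loaded, the op is reached through the dispatcher (for example `torch.ops.my_ext.scale(x, 2.0)` from Python) rather than as an attribute of a pybind11 module.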

53 files changed (+503, -386 lines)

aot_build_utils/generate_aot_default_additional_params_header.py (+6, -6)

```diff
@@ -85,7 +85,7 @@ def get_aot_default_additional_params_header_str() -> str:
             "rope_rcp_scale",
             "rope_rcp_theta",
         ],  # additional_scalar_names
-        ["float", "float", "float", "float"],  # additional_scalar_dtypes
+        ["double", "double", "double", "double"],  # additional_scalar_dtypes
     )
 
     ret += generate_macro_entry(
@@ -98,15 +98,15 @@ def get_aot_default_additional_params_header_str() -> str:
             "rope_rcp_scale",
             "rope_rcp_theta",
         ],
-        ["float", "float", "float", "float"],
+        ["double", "double", "double", "double"],
     )
 
     ret += generate_macro_entry(
         "SINGLE_PREFILL_SM90",
         [],
         [],
         ["logits_soft_cap", "sm_scale"],
-        ["float", "float"],
+        ["double", "double"],
         is_sm90_template=True,
     )
 
@@ -120,7 +120,7 @@ def get_aot_default_additional_params_header_str() -> str:
             "rope_rcp_scale",
             "rope_rcp_theta",
         ],  # additional_scalar_names
-        ["float", "float", "float", "float"],  # additional_scalar_dtypes
+        ["double", "double", "double", "double"],  # additional_scalar_dtypes
     )
 
     ret += generate_macro_entry(
@@ -133,15 +133,15 @@ def get_aot_default_additional_params_header_str() -> str:
             "rope_rcp_scale",
             "rope_rcp_theta",
         ],
-        ["float", "float", "float", "float"],
+        ["double", "double", "double", "double"],
     )
 
     ret += generate_macro_entry(
         "BATCH_PREFILL_SM90",
         [],
         [],
         ["logits_soft_cap", "sm_scale"],
-        ["float", "float"],
+        ["double", "double"],
         is_sm90_template=True,
     )
 
```

csrc/batch_decode.cu (+10, -10)

```diff
@@ -20,6 +20,7 @@
 
 #include "batch_decode_config.inc"
 #include "pytorch_extension_utils.h"
+#include "pytorch_conversion_utils.h"
 
 namespace flashinfer {
 
@@ -32,13 +33,12 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(Params params, typename Params
 
 using namespace flashinfer;
 
-std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
+at::Tensor BatchDecodeWithPagedKVCachePlan(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
-    bool enable_cuda_graph, int window_left, float logits_soft_cap, unsigned int head_dim_qk,
-    unsigned int head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data,
-    int64_t cuda_stream) {
+    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
+    int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size,
+    bool enable_cuda_graph, int64_t window_left, double logits_soft_cap, int64_t head_dim_qk,
+    int64_t head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
   size_t int_workspace_size_in_bytes =
@@ -74,17 +74,17 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
     });
   });
 
-  return plan_info.ToVector();
+  return vec_to_tensor(plan_info.ToVector());
 }
 
 void BatchDecodeWithPagedKVCacheRun(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
+    at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
     at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices,
     at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional<at::Tensor> maybe_lse,
-    unsigned int kv_layout_code, int window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
+    int64_t kv_layout_code, int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
   DecodePlanInfo plan_info;
-  plan_info.FromVector(plan_info_vec);
+  plan_info.FromVector(tensor_to_vec(plan_info_vec));
   QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
   auto device = q.device();
   int64_t batch_size = q.size(0);
```
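
The other recurring change is visible at the end of this file: `plan()` no longer returns the scheduling metadata as a `std::vector<int64_t>`; it packs the vector into an `at::Tensor` with `vec_to_tensor`, and `run()` unpacks it with `tensor_to_vec` before rebuilding `DecodePlanInfo`. The new header `pytorch_conversion_utils.h` is not part of this diff view, so the helpers below are only an assumed sketch that matches the call sites above.

```cpp
#include <ATen/ATen.h>
#include <cstdint>
#include <vector>

// Assumed shape of the helpers from "pytorch_conversion_utils.h" (not shown in
// this commit): pack plan metadata into a CPU int64 tensor and read it back.
inline at::Tensor vec_to_tensor(const std::vector<int64_t>& vec) {
  return at::tensor(vec, at::dtype(at::kLong).device(at::kCPU));
}

inline std::vector<int64_t> tensor_to_vec(const at::Tensor& t) {
  const int64_t* data = t.data_ptr<int64_t>();
  return std::vector<int64_t>(data, data + t.numel());
}
```

Keeping the metadata in a tensor keeps the registered op signatures down to Tensor and plain scalar types on both the plan and run sides.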

csrc/batch_decode_jit_pybind.cu (+12, -11)

```diff
@@ -16,22 +16,23 @@
 #include "batch_decode_config.inc"
 #include "pytorch_extension_utils.h"
 
-std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
+at::Tensor BatchDecodeWithPagedKVCachePlan(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
-    bool enable_cuda_graph, int window_left, float logits_soft_cap, unsigned int head_dim_qk,
-    unsigned int head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data,
-    int64_t cuda_stream);
+    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
+    int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size,
+    bool enable_cuda_graph, int64_t window_left, double logits_soft_cap, int64_t head_dim_qk,
+    int64_t head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream);
 
 void BatchDecodeWithPagedKVCacheRun(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
+    at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
     at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices,
     at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional<at::Tensor> maybe_lse,
-    unsigned int kv_layout_code, int window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);
+    int64_t kv_layout_code, int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("plan", &BatchDecodeWithPagedKVCachePlan, "Batched decode with paged KV-Cache plan");
-  m.def("run", &BatchDecodeWithPagedKVCacheRun, "Batched decode with paged KV-Cache run");
+TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
+  // Batched decode with paged KV-Cache plan
+  m.def("plan", BatchDecodeWithPagedKVCachePlan);
+  // Batched decode with paged KV-Cache run
+  m.def("run", BatchDecodeWithPagedKVCacheRun);
 }
```
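
Alongside the macro swap, every scalar parameter in these declarations widens to a dispatcher-friendly type: `unsigned int` and `int` become `int64_t`, and `float` becomes `double`. The Torch operator schema language has a single integer type (64-bit) and a single floating-point type (double precision), which the C++ signatures now mirror. A toy illustration (not FlashInfer code) of a signature in that style, with hypothetical names:

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// int64_t maps to the schema type "int" and double maps to "float", so this
// signature can be registered with schema inference via m.def("name", fn).
at::Tensor apply_soft_cap(at::Tensor logits, double logits_soft_cap, int64_t window_left) {
  (void)window_left;  // unused in this toy example
  if (logits_soft_cap > 0) {
    return at::tanh(logits / logits_soft_cap) * logits_soft_cap;
  }
  return logits;
}

TORCH_LIBRARY_FRAGMENT(my_ext, m) {
  m.def("apply_soft_cap", apply_soft_cap);
}
```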

csrc/batch_decode_mla_plan.cu (+5, -4)

```diff
@@ -4,13 +4,14 @@
 
 #include "mla_config.inc"
 #include "pytorch_extension_utils.h"
+#include "pytorch_conversion_utils.h"
 
 using namespace flashinfer;
 
-std::vector<int64_t> BatchDecodeWithPagedKVCachePlanMLA(
+at::Tensor BatchDecodeWithPagedKVCachePlanMLA(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int page_size, bool enable_cuda_graph,
+    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
+    int64_t num_qo_heads, int64_t page_size, bool enable_cuda_graph,
     int64_t cuda_stream) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
@@ -35,5 +36,5 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlanMLA(
   TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCachePlanMLA failed with error ",
               cudaGetErrorString(status));
 
-  return plan_info.ToVector();
+  return vec_to_tensor(plan_info.ToVector());
 }
```

csrc/batch_decode_mla_pybind.cu (+9, -9)

```diff
@@ -1,20 +1,20 @@
 #include "mla_config.inc"
 #include "pytorch_extension_utils.h"
 
-std::vector<int64_t> BatchDecodeWithPagedKVCachePlanMLA(
+at::Tensor BatchDecodeWithPagedKVCachePlanMLA(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream);
+    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
+    int64_t num_qo_heads, int64_t page_size, bool enable_cuda_graph, int64_t cuda_stream);
 
 void BatchDecodeWithPagedKVCacheRunMLA(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    std::vector<int64_t> plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
+    at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
     at::Tensor paged_ckv_cache, at::Tensor paged_kpe_cache, at::Tensor paged_kv_indptr,
-    at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, float sm_scale,
-    int window_left, float logits_soft_cap, float rope_scale, float rope_theta,
+    at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, double sm_scale,
+    int64_t window_left, double logits_soft_cap, double rope_scale, double rope_theta,
     std::optional<at::Tensor> maybe_lse, int64_t cuda_stream);
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("plan", &BatchDecodeWithPagedKVCachePlanMLA);
-  m.def("run", &BatchDecodeWithPagedKVCacheRunMLA);
+TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
+  m.def("plan", BatchDecodeWithPagedKVCachePlanMLA);
+  m.def("run", BatchDecodeWithPagedKVCacheRunMLA);
 }
```

csrc/batch_decode_mla_run.cu (+5, -4)

```diff
@@ -4,18 +4,19 @@
 
 #include "mla_config.inc"
 #include "pytorch_extension_utils.h"
+#include "pytorch_conversion_utils.h"
 
 using namespace flashinfer;
 
 void BatchDecodeWithPagedKVCacheRunMLA(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    std::vector<int64_t> plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
+    at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
     at::Tensor paged_ckv_cache, at::Tensor paged_kpe_cache, at::Tensor paged_kv_indptr,
-    at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, float sm_scale,
-    int window_left, float logits_soft_cap, float rope_scale, float rope_theta,
+    at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, double sm_scale,
+    int64_t window_left, double logits_soft_cap, double rope_scale, double rope_theta,
     std::optional<at::Tensor> maybe_lse, int64_t cuda_stream) {
   DecodePlanInfo plan_info;
-  plan_info.FromVector(plan_info_vec);
+  plan_info.FromVector(tensor_to_vec(plan_info_vec));
 
   auto device = q_nope.device();
   int64_t batch_size = q_nope.size(0);
```

csrc/batch_mla_plan.cu (+8, -8)

```diff
@@ -17,17 +17,17 @@
 #include <optional>
 
 #include "batch_mla_config.inc"
+#include "pytorch_conversion_utils.h"
 #include "pytorch_extension_utils.h"
 
 using namespace flashinfer;
 
-std::vector<int64_t> BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer,
-                                                at::Tensor int_workspace_buffer,
-                                                at::Tensor page_locked_int_workspace_buffer,
-                                                at::Tensor qo_indptr, at::Tensor kv_indptr,
-                                                at::Tensor kv_len, unsigned int num_heads,
-                                                unsigned int head_dim_o, bool causal,
-                                                int64_t cuda_stream) {
+at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer,
+                                      at::Tensor int_workspace_buffer,
+                                      at::Tensor page_locked_int_workspace_buffer,
+                                      at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len,
+                                      int64_t num_heads, int64_t head_dim_o, bool causal,
+                                      int64_t cuda_stream) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
   size_t int_workspace_size_in_bytes =
@@ -47,5 +47,5 @@ std::vector<int64_t> BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffe
 
   TORCH_CHECK(status == cudaSuccess, "Failed to plan MLA, error: ", cudaGetErrorString(status));
 
-  return plan_info.ToVector();
+  return vec_to_tensor(plan_info.ToVector());
 }
```

csrc/batch_mla_pybind.cu (+14, -15)

```diff
@@ -16,22 +16,21 @@
 #include "batch_mla_config.inc"
 #include "pytorch_extension_utils.h"
 
-std::vector<int64_t> BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer,
-                                                at::Tensor int_workspace_buffer,
-                                                at::Tensor page_locked_int_workspace_buffer,
-                                                at::Tensor qo_indptr, at::Tensor kv_indptr,
-                                                at::Tensor kv_len, unsigned int num_heads,
-                                                unsigned int head_dim_o, bool causal,
-                                                int64_t cuda_stream);
+at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer,
+                                      at::Tensor int_workspace_buffer,
+                                      at::Tensor page_locked_int_workspace_buffer,
+                                      at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len,
+                                      int64_t num_heads, int64_t head_dim_o, bool causal,
+                                      int64_t cuda_stream);
 
 void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-                               std::vector<int64_t> plan_info_vec, at::Tensor q_nope,
-                               at::Tensor q_pe, at::Tensor ckv_cache, at::Tensor kpe_cache,
-                               at::Tensor kv_indices, at::Tensor o,
-                               std::optional<at::Tensor> maybe_lse, int mask_mode_code,
-                               int num_heads, int page_size, float sm_scale, int64_t cuda_stream);
+                               at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
+                               at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices,
+                               at::Tensor o, std::optional<at::Tensor> maybe_lse,
+                               int64_t mask_mode_code, int64_t num_heads, int64_t page_size,
+                               double sm_scale, int64_t cuda_stream);
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("plan", &BatchMLAPagedAttentionPlan, "Batch MLA Page Attention Plan");
-  m.def("run", &BatchMLAPagedAttentionRun, "Batch MLA Page Attention Run");
+TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
+  m.def("plan", &BatchMLAPagedAttentionPlan);
+  m.def("run", &BatchMLAPagedAttentionRun);
 }
```

csrc/batch_mla_run.cu (+7, -8)

```diff
@@ -13,30 +13,29 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <driver_types.h>
-
 #include <flashinfer/attention/mla_fa2.cuh>
 #include <flashinfer/attention/scheduler.cuh>
 #include <flashinfer/fastdiv.cuh>
 #include <optional>
 
 #include "batch_mla_config.inc"
+#include "pytorch_conversion_utils.h"
 #include "pytorch_extension_utils.h"
 
 using namespace flashinfer;
 
 void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-                               std::vector<int64_t> plan_info_vec, at::Tensor q_nope,
-                               at::Tensor q_pe, at::Tensor ckv_cache, at::Tensor kpe_cache,
-                               at::Tensor kv_indices, at::Tensor o,
-                               std::optional<at::Tensor> maybe_lse, int mask_mode_code,
-                               int num_heads, int page_size, float sm_scale, int64_t cuda_stream) {
+                               at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
+                               at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices,
+                               at::Tensor o, std::optional<at::Tensor> maybe_lse,
+                               int64_t mask_mode_code, int64_t num_heads, int64_t page_size,
+                               double sm_scale, int64_t cuda_stream) {
   // q_nope: [n, num_heads, head_dim_ckv]
   // q_pe: [n, num_heads, head_dim_kpe]
   // ckv_cache: [num_pages, page_size, head_dim_ckv]
   // kpe_cache: [num_pages, page_size, head_dim_kpe]
   MLAPlanInfo plan_info;
-  plan_info.FromVector(plan_info_vec);
+  plan_info.FromVector(tensor_to_vec(plan_info_vec));
 
   auto device = q_nope.device();
 
```
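
Taken together, each plan/run pair in these files now exchanges its scheduling metadata through a tensor rather than a list of integers. Below is a minimal end-to-end sketch of that round trip; `ToyPlanInfo`, `ToyPlan`, and `ToyRun` are hypothetical stand-ins (the real `DecodePlanInfo`, `PrefillPlanInfo`, and `MLAPlanInfo` structs live in FlashInfer headers outside this diff), and `vec_to_tensor` / `tensor_to_vec` are the conversion helpers sketched earlier.

```cpp
#include <ATen/ATen.h>
#include <cstdint>
#include <vector>

// Declared in "pytorch_conversion_utils.h" (assumed; see the earlier sketch).
at::Tensor vec_to_tensor(const std::vector<int64_t>& vec);
std::vector<int64_t> tensor_to_vec(const at::Tensor& t);

// Hypothetical plan-info struct mirroring the ToVector()/FromVector() pattern.
struct ToyPlanInfo {
  int64_t padded_batch_size = 0;
  int64_t split_kv = 0;

  std::vector<int64_t> ToVector() const { return {padded_batch_size, split_kv}; }
  void FromVector(const std::vector<int64_t>& v) {
    padded_batch_size = v.at(0);
    split_kv = v.at(1);
  }
};

// plan(): compute scheduling metadata once and return it as a Tensor.
at::Tensor ToyPlan(int64_t batch_size) {
  ToyPlanInfo info;
  info.padded_batch_size = batch_size;       // placeholder scheduling decision
  info.split_kv = batch_size > 128 ? 1 : 0;  // placeholder scheduling decision
  return vec_to_tensor(info.ToVector());
}

// run(): rebuild the metadata from the Tensor produced by plan().
void ToyRun(const at::Tensor& plan_info_vec) {
  ToyPlanInfo info;
  info.FromVector(tensor_to_vec(plan_info_vec));
  // ... launch kernels using info.padded_batch_size / info.split_kv ...
}
```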

csrc/batch_prefill.cu (+13, -12)

```diff
@@ -20,6 +20,7 @@
 
 #include "batch_prefill_config.inc"
 #include "pytorch_extension_utils.h"
+#include "pytorch_conversion_utils.h"
 
 namespace flashinfer {
 
@@ -39,12 +40,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para
 
 using namespace flashinfer;
 
-std::vector<int64_t> BatchPrefillWithKVCachePlan(
+at::Tensor BatchPrefillWithKVCachePlan(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
-    at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
-    bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal,
+    at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size,
+    int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size,
+    bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal,
     int64_t cuda_stream) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
@@ -64,17 +65,17 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
   TORCH_CHECK(status == cudaSuccess,
               "Failed to plan prefill with error: ", cudaGetErrorString(status));
 
-  return plan_info.ToVector();
+  return vec_to_tensor(plan_info.ToVector());
 }
 
 void BatchPrefillWithRaggedKVCacheRun(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v,
+    at::Tensor plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v,
     at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o, std::optional<at::Tensor> maybe_lse,
-    unsigned int mask_mode_code, unsigned int layout, int32_t window_left ADDITIONAL_FUNC_PARAMS,
+    int64_t mask_mode_code, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS,
    int64_t cuda_stream) {
   PrefillPlanInfo plan_info;
-  plan_info.FromVector(plan_info_vec);
+  plan_info.FromVector(tensor_to_vec(plan_info_vec));
   QKVLayout kv_layout = static_cast<QKVLayout>(layout);
 
   int64_t num_qo_heads = q.size(1);
@@ -194,13 +195,13 @@ void BatchPrefillWithRaggedKVCacheRun(
 
 void BatchPrefillWithPagedKVCacheRun(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
+    at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
     at::Tensor paged_v_cache, at::Tensor qo_indptr, at::Tensor paged_kv_indptr,
     at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o,
-    std::optional<at::Tensor> maybe_lse, unsigned int mask_mode_code, unsigned int layout,
-    int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
+    std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
+    int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
   PrefillPlanInfo plan_info;
-  plan_info.FromVector(plan_info_vec);
+  plan_info.FromVector(tensor_to_vec(plan_info_vec));
   QKVLayout kv_layout = static_cast<QKVLayout>(layout);
   auto device = q.device();
   int64_t batch_size = paged_kv_indptr.size(0) - 1;
```
