20
20
21
21
#include " batch_prefill_config.inc"
22
22
#include " pytorch_extension_utils.h"
23
+ #include " pytorch_conversion_utils.h"
23
24
24
25
namespace flashinfer {
25
26
@@ -39,12 +40,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para
39
40
40
41
using namespace flashinfer ;
41
42
42
- std::vector< int64_t > BatchPrefillWithKVCachePlan (
43
+ at::Tensor BatchPrefillWithKVCachePlan (
43
44
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
44
45
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
45
- at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size,
46
- unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
47
- bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal,
46
+ at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size,
47
+ int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size,
48
+ bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal,
48
49
int64_t cuda_stream) {
49
50
size_t float_workspace_size_in_bytes =
50
51
float_workspace_buffer.size (0 ) * float_workspace_buffer.element_size ();
@@ -64,17 +65,17 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
64
65
TORCH_CHECK (status == cudaSuccess,
65
66
" Failed to plan prefill with error: " , cudaGetErrorString (status));
66
67
67
- return plan_info.ToVector ();
68
+ return vec_to_tensor ( plan_info.ToVector () );
68
69
}
69
70
70
71
void BatchPrefillWithRaggedKVCacheRun (
71
72
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
72
- std::vector< int64_t > plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v,
73
+ at::Tensor plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v,
73
74
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o, std::optional<at::Tensor> maybe_lse,
74
- unsigned int mask_mode_code, unsigned int layout, int32_t window_left ADDITIONAL_FUNC_PARAMS,
75
+ int64_t mask_mode_code, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS,
75
76
int64_t cuda_stream) {
76
77
PrefillPlanInfo plan_info;
77
- plan_info.FromVector (plan_info_vec);
78
+ plan_info.FromVector (tensor_to_vec ( plan_info_vec) );
78
79
QKVLayout kv_layout = static_cast <QKVLayout>(layout);
79
80
80
81
int64_t num_qo_heads = q.size (1 );
@@ -194,13 +195,13 @@ void BatchPrefillWithRaggedKVCacheRun(
194
195
195
196
void BatchPrefillWithPagedKVCacheRun (
196
197
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
197
- std::vector< int64_t > plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
198
+ at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
198
199
at::Tensor paged_v_cache, at::Tensor qo_indptr, at::Tensor paged_kv_indptr,
199
200
at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o,
200
- std::optional<at::Tensor> maybe_lse, unsigned int mask_mode_code, unsigned int layout,
201
- int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
201
+ std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
202
+ int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
202
203
PrefillPlanInfo plan_info;
203
- plan_info.FromVector (plan_info_vec);
204
+ plan_info.FromVector (tensor_to_vec ( plan_info_vec) );
204
205
QKVLayout kv_layout = static_cast <QKVLayout>(layout);
205
206
auto device = q.device ();
206
207
int64_t batch_size = paged_kv_indptr.size (0 ) - 1 ;
0 commit comments