29
29
namespace flashinfer {
30
30
31
31
// Forward declaration: dispatched entry point for batch prefill attention over a
// ragged (non-paged) KV cache on SM90.
//
// Template parameters:
//   HEAD_DIM                    - compile-time attention head dimension.
//   MASK_MODE                   - compile-time mask mode (causal / non-causal / custom).
//   LEFT_SLINDING_WINDOW        - enables left sliding-window attention.
//                                 NOTE(review): spelling looks like a typo for
//                                 "LEFT_SLIDING_WINDOW", but it matches the
//                                 identifier used at the definition site — do not
//                                 rename here without renaming everywhere.
//   SAME_SCHEDULE_FOR_ALL_HEADS - when true, all QO heads share one work schedule.
//   AttentionVariant            - attention variant policy (e.g. LogitsSoftCap /
//                                 StandardAttention).
//   DTypeQ / DTypeKV / DTypeO   - element types of query, key/value, and output.
//   IdType                      - integer type used for index arrays.
//
// Returns cudaSuccess on success, or the CUDA error raised by the kernel launch.
// The launch is asynchronous on `stream`; errors may surface at a later sync.
template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
          typename DTypeKV, typename DTypeO, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
    BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream);
36
36
37
37
// Forward declaration: dispatched entry point for batch prefill attention over a
// paged KV cache on SM90.
//
// Template parameters:
//   HEAD_DIM                    - compile-time attention head dimension.
//   MASK_MODE                   - compile-time mask mode (causal / non-causal / custom).
//   LEFT_SLINDING_WINDOW        - enables left sliding-window attention.
//                                 NOTE(review): spelling looks like a typo for
//                                 "LEFT_SLIDING_WINDOW", but it matches the
//                                 identifier used at the definition site — do not
//                                 rename here without renaming everywhere.
//   SAME_SCHEDULE_FOR_ALL_HEADS - when true, all QO heads share one work schedule.
//   AttentionVariant            - attention variant policy (e.g. LogitsSoftCap /
//                                 StandardAttention).
//   DTypeQ / DTypeKV / DTypeO   - element types of query, key/value, and output.
//   IdType                      - integer type used for index arrays (page tables,
//                                 indptr arrays).
//
// Returns cudaSuccess on success, or the CUDA error raised by the kernel launch.
// The launch is asynchronous on `stream`; errors may surface at a later sync.
template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
          typename DTypeKV, typename DTypeO, typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(
    BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream);
42
42
@@ -47,9 +47,9 @@ using namespace flashinfer;
47
47
std::vector<int64_t > BatchPrefillWithKVCacheSM90Plan (
48
48
unsigned int head_dim, bool causal, at::Tensor float_workspace_buffer,
49
49
at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer,
50
- at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int batch_size ,
51
- unsigned int num_qo_heads , unsigned int num_kv_heads , unsigned int page_size ,
52
- bool enable_cuda_graph, int64_t cuda_stream) {
50
+ at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int total_num_rows ,
51
+ unsigned int batch_size , unsigned int num_qo_heads , unsigned int num_kv_heads ,
52
+ unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream) {
53
53
size_t float_workspace_size_in_bytes =
54
54
float_workspace_buffer.size (0 ) * float_workspace_buffer.element_size ();
55
55
size_t int_workspace_size_in_bytes =
@@ -61,12 +61,13 @@ std::vector<int64_t> BatchPrefillWithKVCacheSM90Plan(
61
61
62
62
cudaStream_t stream = reinterpret_cast <cudaStream_t>(cuda_stream);
63
63
64
- cudaError_t status = PrefillSM90Plan (
65
- float_workspace_buffer.data_ptr (), float_workspace_size_in_bytes,
66
- int_workspace_buffer.data_ptr (), page_locked_int_workspace_buffer.data_ptr (),
67
- int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr <IdType>(),
68
- kv_indptr.data_ptr <IdType>(), kv_len_arr.data_ptr <IdType>(), batch_size, num_qo_heads,
69
- num_kv_heads, head_dim, page_size, causal, enable_cuda_graph, /* sizeof_dtype_o=*/ 2 , stream);
64
+ cudaError_t status =
65
+ PrefillSM90Plan (float_workspace_buffer.data_ptr (), float_workspace_size_in_bytes,
66
+ int_workspace_buffer.data_ptr (), page_locked_int_workspace_buffer.data_ptr (),
67
+ int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr <IdType>(),
68
+ kv_indptr.data_ptr <IdType>(), kv_len_arr.data_ptr <IdType>(), total_num_rows,
69
+ batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, causal,
70
+ enable_cuda_graph, /* sizeof_dtype_o=*/ 2 , stream);
70
71
71
72
TORCH_CHECK (status == cudaSuccess,
72
73
" PrefillSM90Plan failed with error: " , cudaGetErrorString (status));
@@ -151,19 +152,23 @@ void BatchPrefillWithRaggedKVCacheSM90Run(
151
152
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset );
152
153
params.work_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset );
153
154
155
+ bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads ;
156
+
154
157
return DISPATCH_head_dim (head_dim, HEAD_DIM, [&] {
155
158
return DISPATCH_mask_mode (mask_mode, MASK_MODE, [&] {
156
159
return DISPATCH_BOOL (use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
157
160
return DISPATCH_BOOL (use_swa, USE_SWA, [&] {
158
- using AttentionVariant =
159
- std::conditional_t <USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
160
- cudaError_t status =
161
- BatchPrefillWithRaggedKVCacheDispatched<HEAD_DIM, MASK_MODE, USE_SWA,
162
- AttentionVariant>(params, stream);
163
- TORCH_CHECK (status == cudaSuccess,
164
- " BatchPrefillWithRaggedKVCacheSM90Run failed with error: " ,
165
- cudaGetErrorString (status));
166
- return true ;
161
+ return DISPATCH_BOOL (same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
162
+ using AttentionVariant =
163
+ std::conditional_t <USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
164
+ cudaError_t status = BatchPrefillWithRaggedKVCacheDispatched<
165
+ HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
166
+ params, stream);
167
+ TORCH_CHECK (status == cudaSuccess,
168
+ " BatchPrefillWithRaggedKVCacheSM90Run failed with error: " ,
169
+ cudaGetErrorString (status));
170
+ return true ;
171
+ });
167
172
});
168
173
});
169
174
});
@@ -259,20 +264,23 @@ void BatchPrefillWithPagedKVCacheSM90Run(
259
264
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset );
260
265
params.work_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset );
261
266
params.kv_indices = static_cast <IdType*>(paged_kv_indices.data_ptr ());
267
+ bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads ;
262
268
263
269
return DISPATCH_head_dim (head_dim, HEAD_DIM, [&] {
264
270
return DISPATCH_mask_mode (mask_mode, MASK_MODE, [&] {
265
271
return DISPATCH_BOOL (use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
266
272
return DISPATCH_BOOL (use_swa, USE_SWA, [&] {
267
- using AttentionVariant =
268
- std::conditional_t <USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
269
- cudaError_t status =
270
- BatchPrefillWithPagedKVCacheDispatched<HEAD_DIM, MASK_MODE, USE_SWA,
271
- AttentionVariant>(params, stream);
272
- TORCH_CHECK (status == cudaSuccess,
273
- " BatchPrefillWithPagedKVCacheSM90Run failed with error: " ,
274
- cudaGetErrorString (status));
275
- return true ;
273
+ return DISPATCH_BOOL (same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
274
+ using AttentionVariant =
275
+ std::conditional_t <USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
276
+ cudaError_t status = BatchPrefillWithPagedKVCacheDispatched<
277
+ HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
278
+ params, stream);
279
+ TORCH_CHECK (status == cudaSuccess,
280
+ " BatchPrefillWithPagedKVCacheSM90Run failed with error: " ,
281
+ cudaGetErrorString (status));
282
+ return true ;
283
+ });
276
284
});
277
285
});
278
286
});
0 commit comments