 using namespace flashinfer;

 void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
-    torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr,
-    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-    unsigned int head_dim, unsigned int page_size, torch::Tensor empty_q_data) {
-  CHECK_INPUT(workspace_buffer);
+    torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
+    torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, unsigned int batch_size,
+    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
+    unsigned int page_size, torch::Tensor empty_q_data) {
+  CHECK_INPUT(float_workspace_buffer);
+  CHECK_INPUT(int_workspace_buffer);
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_CONTIGUOUS(paged_kv_indptr);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   CHECK_DIM(1, qo_indptr);
   CHECK_DIM(1, paged_kv_indptr);
-  CHECK_DIM(1, workspace_buffer);
+  CHECK_DIM(1, float_workspace_buffer);
+  CHECK_DIM(1, int_workspace_buffer);
   CHECK_EQ(qo_indptr.size(0), batch_size + 1);
   CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1);
   qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
   paged_kv_indptr = paged_kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
-  auto device = workspace_buffer.device();
-  size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
+  auto device = float_workspace_buffer.device();
+  size_t float_workspace_size_in_bytes =
+      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
+  size_t int_workspace_size_in_bytes =
+      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   handler_->SetCUDAStream(torch_current_stream);

   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
     cudaError_t status = handler_->BeginForward<q_type, int32_t>(
-        static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
+        static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
+        static_cast<void*>(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes,
         static_cast<int32_t*>(qo_indptr.data_ptr()),
         static_cast<int32_t*>(paged_kv_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads,
         head_dim, page_size);
@@ -56,8 +63,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
 void BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }

 void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
-    unsigned int max_workspace_size_in_bytes) {
-  handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes);
+    unsigned int int_workspace_size_in_bytes) {
+  handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
 }

 std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
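Note: after this change the paged-KV wrapper expects the caller to pass two separate workspace tensors instead of one. A minimal caller-side sketch of the updated convention follows (the helper name plan_paged_prefill_example and the buffer sizes are illustrative assumptions, not part of this patch):

// Sketch only: allocate the split workspaces and plan a paged-KV prefill.
// Buffer sizes below are placeholders; size them to fit the real workload.
void plan_paged_prefill_example(BatchPrefillWithPagedKVCachePyTorchWrapper& wrapper,
                                torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr,
                                unsigned int batch_size, unsigned int num_qo_heads,
                                unsigned int num_kv_heads, unsigned int head_dim,
                                unsigned int page_size) {
  auto buf_options = torch::dtype(torch::kUInt8).device(torch::kCUDA);
  torch::Tensor float_workspace_buffer = torch::empty({128 * 1024 * 1024}, buf_options);  // fp scratch
  torch::Tensor int_workspace_buffer = torch::empty({8 * 1024 * 1024}, buf_options);      // planning metadata
  // empty_q_data only carries the query dtype consumed by the dispatch macro.
  torch::Tensor empty_q_data = torch::empty({0}, torch::dtype(torch::kHalf).device(torch::kCUDA));
  wrapper.BeginForward(float_workspace_buffer, int_workspace_buffer, qo_indptr, paged_kv_indptr,
                       batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, empty_q_data);
}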
@@ -446,28 +453,35 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
 }

 void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
-    torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor kv_indptr,
-    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-    unsigned int head_dim, torch::Tensor empty_q_data) {
-  CHECK_INPUT(workspace_buffer);
+    torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
+    torch::Tensor qo_indptr, torch::Tensor kv_indptr, unsigned int batch_size,
+    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
+    torch::Tensor empty_q_data) {
+  CHECK_INPUT(float_workspace_buffer);
+  CHECK_INPUT(int_workspace_buffer);
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   CHECK_DIM(1, qo_indptr);
   CHECK_DIM(1, kv_indptr);
-  CHECK_DIM(1, workspace_buffer);
+  CHECK_DIM(1, float_workspace_buffer);
+  CHECK_DIM(1, int_workspace_buffer);
   CHECK_EQ(qo_indptr.size(0), batch_size + 1);
   CHECK_EQ(kv_indptr.size(0), batch_size + 1);
   qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
   kv_indptr = kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
-  size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
-  auto device = workspace_buffer.device();
+  size_t float_workspace_size_in_bytes =
+      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
+  size_t int_workspace_size_in_bytes =
+      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
+  auto device = float_workspace_buffer.device();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   handler_->SetCUDAStream(torch_current_stream);

   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
     cudaError_t status = handler_->BeginForward<q_type, int32_t>(
-        static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
+        static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
+        static_cast<void*>(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes,
         static_cast<int32_t*>(qo_indptr.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
         batch_size, num_qo_heads, num_kv_heads, head_dim,
         /*page_size=*/1);
@@ -480,8 +494,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
 void BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }

 void BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
-    unsigned int max_workspace_size_in_bytes) {
-  handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes);
+    unsigned int int_workspace_size_in_bytes) {
+  handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
 }

 std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
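For reference, the handler-side BeginForward signature implied by the updated call sites in both wrappers (a sketch reconstructed purely from the arguments passed here; the exact declaration in the handler's header may differ in qualifiers or parameter names):

// Reconstructed from the call sites in this diff; not copied from the header.
// The float workspace holds floating-point scratch on the GPU, while the int
// workspace holds planning metadata, whose size also drives the page-locked
// host buffer via UpdatePageLockedBufferSize above.
template <typename DTypeQ, typename IdType>
cudaError_t BeginForward(void* float_workspace_buffer, size_t float_workspace_size_in_bytes,
                         void* int_workspace_buffer, size_t int_workspace_size_in_bytes,
                         IdType* qo_indptr, IdType* kv_indptr, uint32_t batch_size,
                         uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
                         uint32_t page_size);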