
Commit a7ee566

feat: decouple float and int workspace buffer (#442)
Before this PR, flashinfer packed the float and int arrays into a single workspace buffer, so different wrappers could not share the same buffer. This PR decouples the float and int workspace buffers: the float workspace buffer (large) can be shared by multiple wrappers, while the int workspace buffer (small) stays unique to each wrapper. This saves GPU memory when multiple wrappers are created (decode, paged prefill, ragged prefill) or in cascade inference.
1 parent: 3fff008 · commit: a7ee566

19 files changed: +467 -275 lines
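To illustrate the sharing pattern described above, here is a minimal Python sketch (buffer sizes are illustrative, and it is assumed that each wrapper allocates its own small int workspace buffer internally, consistent with the cascade wrappers below passing only the float buffer to their inner wrappers):

```python
import torch
import flashinfer

# One large float workspace buffer (holds split-k partial attention results),
# now shareable across every wrapper.
float_workspace_buffer = torch.empty(
    128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"
)

# Creating several wrappers no longer multiplies the large allocation;
# each wrapper keeps only its own small int workspace buffer.
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    float_workspace_buffer, "NHD"
)
paged_prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    float_workspace_buffer, "NHD"
)
ragged_prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
    float_workspace_buffer, "NHD"
)
```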

Diff for: include/flashinfer/attention/handler.cuh (+78 -71)

Large diffs are not rendered by default.

Diff for: python/csrc/activation.cu (+1 -2)

@@ -51,8 +51,7 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     uint32_t vec_size = 16 / sizeof(c_type);
     dim3 block(std::min(d / vec_size, 1024U));
-    flashinfer::activation::act_and_mul_kernel<c_type,
-                                               flashinfer::activation::gelu_tanh_kernel>
+    flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::gelu_tanh_kernel>
         <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
                                      static_cast<c_type*>(input.data_ptr()), d);

Diff for: python/csrc/batch_decode.cu (+24 -14)

@@ -21,22 +21,28 @@
 using namespace flashinfer;
 
 void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
-    torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len,
-    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-    unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode,
-    float logits_soft_cap, torch::Tensor empty_q_data, torch::Tensor empty_kv_data) {
-  CHECK_INPUT(workspace_buffer);
+    torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor indptr,
+    torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
+    unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
+    unsigned int pos_encoding_mode, float logits_soft_cap, torch::Tensor empty_q_data,
+    torch::Tensor empty_kv_data) {
+  CHECK_INPUT(float_workspace_buffer);
+  CHECK_INPUT(int_workspace_buffer);
   // NOTE(zihao): not necessary to be CUDA tensor
   CHECK_CONTIGUOUS(indptr);
   CHECK_CONTIGUOUS(last_page_len);
   CHECK_DIM(1, indptr);
   CHECK_DIM(1, last_page_len);
-  CHECK_DIM(1, workspace_buffer);
+  CHECK_DIM(1, float_workspace_buffer);
+  CHECK_DIM(1, int_workspace_buffer);
   CHECK_EQ(indptr.scalar_type(), torch::kInt32);
   CHECK_EQ(indptr.scalar_type(), torch::kInt32);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
-  size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
-  auto device = workspace_buffer.device();
+  size_t float_workspace_size_in_bytes =
+      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
+  size_t int_workspace_size_in_bytes =
+      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
+  auto device = float_workspace_buffer.device();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   handler_->SetCUDAStream(torch_current_stream);
   indptr = indptr.to(torch::kCPU);

@@ -59,8 +65,10 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
         handler_->BeginForwardDispatched<HEAD_DIM, PageStorage::kIndices,
                                          LOGITS_POST_HOOK, POS_ENCODING_MODE, qkv_type,
                                          qkv_type, qkv_type, int32_t>(
-            static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
-            static_cast<int32_t*>(indptr.data_ptr()),
+            static_cast<void*>(float_workspace_buffer.data_ptr()),
+            float_workspace_size_in_bytes,
+            static_cast<void*>(int_workspace_buffer.data_ptr()),
+            int_workspace_size_in_bytes, static_cast<int32_t*>(indptr.data_ptr()),
             static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
             num_kv_heads, page_size);
         TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",

@@ -81,8 +89,10 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
         handler_->BeginForwardDispatched<HEAD_DIM, PageStorage::kIndices,
                                          LOGITS_POST_HOOK, POS_ENCODING_MODE, q_type,
                                          kv_type, q_type, int32_t>(
-            static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
-            static_cast<int32_t*>(indptr.data_ptr()),
+            static_cast<void*>(float_workspace_buffer.data_ptr()),
+            float_workspace_size_in_bytes,
+            static_cast<void*>(int_workspace_buffer.data_ptr()),
+            int_workspace_size_in_bytes, static_cast<int32_t*>(indptr.data_ptr()),
             static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
             num_kv_heads, page_size);
         TORCH_CHECK(status == cudaSuccess,

@@ -100,8 +110,8 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
 void BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }
 
 void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
-    unsigned int max_workspace_size_in_bytes) {
-  handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes);
+    unsigned int int_workspace_size_in_bytes) {
+  handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
 }
 
 std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
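On the Python side, the C++ BeginForward above is reached through the decode wrapper's begin_forward plan call. A hedged sketch follows, reusing `decode_wrapper` from the first sketch; the argument list follows flashinfer's Python API of this period and the tensor contents are illustrative, so treat the exact names as assumptions:

```python
import torch

# Paged KV-cache metadata for a batch of 2 requests spanning 5 pages (illustrative).
kv_page_indptr = torch.tensor([0, 2, 5], dtype=torch.int32, device="cuda:0")
kv_page_indices = torch.arange(5, dtype=torch.int32, device="cuda:0")
kv_last_page_len = torch.tensor([7, 16], dtype=torch.int32, device="cuda:0")

# Plans the split-k schedule: scheduling metadata lands in the wrapper's own
# int workspace, while the shared float workspace is reserved for partial
# attention outputs.
decode_wrapper.begin_forward(
    kv_page_indptr,
    kv_page_indices,
    kv_last_page_len,
    num_qo_heads=32,
    num_kv_heads=32,
    head_dim=128,
    page_size=16,
    data_type=torch.float16,
)
```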

Diff for: python/csrc/batch_prefill.cu (+34 -20)

@@ -21,29 +21,36 @@
 using namespace flashinfer;
 
 void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
-    torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr,
-    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-    unsigned int head_dim, unsigned int page_size, torch::Tensor empty_q_data) {
-  CHECK_INPUT(workspace_buffer);
+    torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
+    torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, unsigned int batch_size,
+    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
+    unsigned int page_size, torch::Tensor empty_q_data) {
+  CHECK_INPUT(float_workspace_buffer);
+  CHECK_INPUT(int_workspace_buffer);
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_CONTIGUOUS(paged_kv_indptr);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   CHECK_DIM(1, qo_indptr);
   CHECK_DIM(1, paged_kv_indptr);
-  CHECK_DIM(1, workspace_buffer);
+  CHECK_DIM(1, float_workspace_buffer);
+  CHECK_DIM(1, int_workspace_buffer);
   CHECK_EQ(qo_indptr.size(0), batch_size + 1);
   CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1);
   qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
   paged_kv_indptr = paged_kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
-  auto device = workspace_buffer.device();
-  size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
+  auto device = float_workspace_buffer.device();
+  size_t float_workspace_size_in_bytes =
+      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
+  size_t int_workspace_size_in_bytes =
+      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   handler_->SetCUDAStream(torch_current_stream);
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
     cudaError_t status = handler_->BeginForward<q_type, int32_t>(
-        static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
+        static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
+        static_cast<void*>(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes,
         static_cast<int32_t*>(qo_indptr.data_ptr()),
         static_cast<int32_t*>(paged_kv_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads,
         head_dim, page_size);

@@ -56,8 +63,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
 void BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }
 
 void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
-    unsigned int max_workspace_size_in_bytes) {
-  handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes);
+    unsigned int int_workspace_size_in_bytes) {
+  handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
 }
 
 std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(

@@ -446,28 +453,35 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
 }
 
 void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
-    torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor kv_indptr,
-    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-    unsigned int head_dim, torch::Tensor empty_q_data) {
-  CHECK_INPUT(workspace_buffer);
+    torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
+    torch::Tensor qo_indptr, torch::Tensor kv_indptr, unsigned int batch_size,
+    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
+    torch::Tensor empty_q_data) {
+  CHECK_INPUT(float_workspace_buffer);
+  CHECK_INPUT(int_workspace_buffer);
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   CHECK_DIM(1, qo_indptr);
   CHECK_DIM(1, kv_indptr);
-  CHECK_DIM(1, workspace_buffer);
+  CHECK_DIM(1, float_workspace_buffer);
+  CHECK_DIM(1, int_workspace_buffer);
   CHECK_EQ(qo_indptr.size(0), batch_size + 1);
   CHECK_EQ(kv_indptr.size(0), batch_size + 1);
   qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
   kv_indptr = kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
-  size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
-  auto device = workspace_buffer.device();
+  size_t float_workspace_size_in_bytes =
+      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
+  size_t int_workspace_size_in_bytes =
+      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
+  auto device = float_workspace_buffer.device();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   handler_->SetCUDAStream(torch_current_stream);
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
     cudaError_t status = handler_->BeginForward<q_type, int32_t>(
-        static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
+        static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
+        static_cast<void*>(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes,
         static_cast<int32_t*>(qo_indptr.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
         batch_size, num_qo_heads, num_kv_heads, head_dim,
         /*page_size=*/1);

@@ -480,8 +494,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
 void BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }
 
 void BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
-    unsigned int max_workspace_size_in_bytes) {
-  handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes);
+    unsigned int int_workspace_size_in_bytes) {
+  handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
 }
 
 std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(

Diff for: python/csrc/flashinfer_ops_decode.h (+5 -5)

@@ -28,13 +28,13 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
 
 class BatchDecodeWithPagedKVCachePyTorchWrapper {
  public:
-  void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr,
-                    torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
-                    unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
-                    unsigned int pos_encoding_mode, float logits_soft_cap,
+  void BeginForward(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
+                    torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size,
+                    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
+                    unsigned int page_size, unsigned int pos_encoding_mode, float logits_soft_cap,
                     torch::Tensor empty_q_data, torch::Tensor empty_kv_data);
   void EndForward();
-  void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
+  void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes);
   bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
   std::vector<torch::Tensor> Forward(torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache,
                                      std::optional<torch::Tensor> paged_k_cache,

Diff for: python/csrc/flashinfer_ops_prefill.h (+8 -7)

@@ -34,13 +34,13 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache_custom_mask(
 
 class BatchPrefillWithPagedKVCachePyTorchWrapper {
  public:
-  void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
-                    torch::Tensor page_kv_indptr, unsigned int batch_size,
+  void BeginForward(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
+                    torch::Tensor qo_indptr, torch::Tensor page_kv_indptr, unsigned int batch_size,
                     unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
                     unsigned page_size, torch::Tensor empty_q_data);
   void EndForward();
   bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
-  void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
+  void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes);
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
                                      std::optional<torch::Tensor> paged_kv_cache,
                                      std::optional<torch::Tensor> paged_k_cache,

@@ -69,12 +69,13 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {
 
 class BatchPrefillWithRaggedKVCachePyTorchWrapper {
  public:
-  void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
-                    torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads,
-                    unsigned int num_kv_heads, unsigned int head_dim, torch::Tensor empty_q_data);
+  void BeginForward(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
+                    torch::Tensor qo_indptr, torch::Tensor kv_indptr, unsigned int batch_size,
+                    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
+                    torch::Tensor empty_q_data);
   void EndForward();
   bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
-  void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
+  void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes);
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k,
                                      torch::Tensor v, torch::Tensor kv_indptr, bool causal,
                                      unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,

Diff for: python/flashinfer/cascade.py (+36 -16)

@@ -257,22 +257,32 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
     manages the lifecycle of these data structures.
     """
 
-    def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD") -> None:
+    def __init__(
+        self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
+    ) -> None:
         self._batch_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            workspace_buffer, kv_layout
+            float_workspace_buffer, kv_layout
         )
         self._kv_layout = kv_layout
 
-    def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None:
+    def reset_workspace_buffer(
+        self, float_workspace_buffer: torch.Tensor, int_workspace_buffer
+    ) -> None:
         r"""Reset the workspace buffer.
 
         Parameters
         ----------
-        new_workspace_buffer : torch.Tensor
-            The new workspace buffer, the device of the new workspace buffer should
+        float_workspace_buffer : torch.Tensor
+            The new float workspace buffer, the device of the new float workspace buffer should
+            be the same as the device of the input tensors.
+
+        int_workspace_buffer : torch.Tensor
+            The new int workspace buffer, the device of the new int workspace buffer should
             be the same as the device of the input tensors.
         """
-        self._batch_decode_wrapper.reset_workspace_buffer(new_workspace_buffer)
+        self._batch_decode_wrapper.reset_workspace_buffer(
+            float_workspace_buffer, int_workspace_buffer
+        )
 
     def begin_forward(
         self,

@@ -503,33 +513,43 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
     layers). This wrapper class manages the lifecycle of these data structures.
     """
 
-    def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD") -> None:
+    def __init__(
+        self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
+    ) -> None:
         r"""Constructor of :class:`BatchDecodeWithSharedPrefixPagedKVCacheWrapper`.
 
         Parameters
         ----------
-        workspace_buffer : torch.Tensor
-            The user reserved workspace buffer used to store auxiliary data structures,
-            recommended size is 128MB, the device of the workspace buffer should be the
-            same as the device of the input tensors.
+        float_workspace_buffer : torch.Tensor
+            The user reserved float workspace buffer used to store intermediate attention results
+            in the split-k algorithm. The recommended size is 128MB, the device of the workspace
+            buffer should be the same as the device of the input tensors.
         kv_layout : str
             The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
         """
         self._batch_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-            workspace_buffer, kv_layout
+            float_workspace_buffer, kv_layout
        )
        self._kv_layout = kv_layout
 
-    def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None:
+    def reset_workspace_buffer(
+        self, float_workspace_buffer: torch.Tensor, int_workspace_buffer
+    ) -> None:
         r"""Reset the workspace buffer.
 
         Parameters
         ----------
-        new_workspace_buffer : torch.Tensor
-            The new workspace buffer, the device of the new workspace buffer should
+        float_workspace_buffer : torch.Tensor
+            The new float workspace buffer, the device of the new float workspace buffer should
+            be the same as the device of the input tensors.
+
+        int_workspace_buffer : torch.Tensor
+            The new int workspace buffer, the device of the new int workspace buffer should
             be the same as the device of the input tensors.
         """
-        self._batch_prefill_wrapper.reset_workspace_buffer(new_workspace_buffer)
+        self._batch_prefill_wrapper.reset_workspace_buffer(
+            float_workspace_buffer, int_workspace_buffer
+        )
 
     def begin_forward(
         self,
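A short usage sketch of the two-argument reset_workspace_buffer introduced above; the sizes are illustrative, and `wrapper` is a hypothetical stand-in for any of the shared-prefix wrapper instances in this file:

```python
import torch

# Illustrative sizes: the float buffer is the large shared one, the int
# buffer only needs to hold per-wrapper scheduling metadata.
new_float_buffer = torch.empty(
    256 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"
)
new_int_buffer = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")

# Both buffers must live on the same device as the input tensors.
wrapper.reset_workspace_buffer(new_float_buffer, new_int_buffer)
```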
