Skip to content

Commit 457a0ae

Browse files
authored
feat: expose decoupled kv-cache to pytorch api (#383)
Followup of #379
1 parent c6f20d1 commit 457a0ae

17 files changed

+966
-287
lines changed

docs/api/python/cascade.rst

-6
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,6 @@ Merge Attention States
2222
Cascade Attention
2323
-----------------
2424

25-
.. autosummary::
26-
:toctree: ../../generated
27-
28-
batch_decode_with_shared_prefix_padded_kv_cache
29-
30-
3125
Cascade Attention Wrapper Classes
3226
---------------------------------
3327

docs/api/python/decode.rst

-6
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,6 @@ Single Request Decoding
1616
Batch Decoding
1717
--------------
1818

19-
.. autosummary::
20-
:toctree: ../../generated
21-
22-
batch_decode_with_padded_kv_cache
23-
batch_decode_with_padded_kv_cache_return_lse
24-
2519
.. autoclass:: BatchDecodeWithPagedKVCacheWrapper
2620
:members:
2721

docs/tutorials/kv_layout.rst

+13-4
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,23 @@ The overall ``kv_indptr`` array (with length ``num_requests+1``) can be computed
119119
The overall ``kv_page_indices`` array (with length ``kv_indptr[-1]``) is the concatenation of all requests' ``page_indices``.
120120
The overall ``kv_last_page_lens`` array (with length ``num_requests``) is the concatenation of all requests' ``last_page_length``.
121121

122-
The ``kv_data`` tensor is a 5-D tensor with shape (in ``NHD`` layout):
122+
The ``kv_data`` tensor can either be a single 5-D tensor or a tuple of 4-D tensors.
123+
When stored in a single tensor, ``kv_data`` has shape:
123124

124-
.. code::
125+
.. code:: python
125126
126-
(max_num_pages, 2, page_size, num_heads, head_dim)
127+
(max_num_pages, 2, page_size, num_heads, head_dim) # NHD layout
128+
(max_num_pages, 2, num_heads, page_size, head_dim) # HND layout
129+
130+
When stored as a tuple of tensors, ``kv_data = (k_data, v_data)``, each of which has shape:
131+
132+
.. code:: python
133+
134+
(max_num_pages, page_size, num_heads, head_dim) # NHD layout
135+
(max_num_pages, num_heads, page_size, head_dim) # HND layout
127136
128137
where ``max_num_pages`` is the maximum number of pages used by all requests, ``page_size`` is the number of tokens
129-
we fit into each page. ``2`` is the number of slots in each page (first one for keys, the second one for values).
138+
we fit into each page. In the single-tensor storage, ``2`` is the number of K/V slots per page (the first slot stores keys, the second stores values).
130139

131140
FlashInfer APIs
132141
~~~~~~~~~~~~~~~

include/flashinfer/page.cuh

+42
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,48 @@ struct paged_kv_t {
115115
last_page_len(nullptr),
116116
rope_pos_offset(nullptr) {}
117117

118+
/*!
119+
* \brief Construct a paged key-value cache
120+
* \param num_heads The number of heads
121+
* \param page_size The size of each page
122+
* \param head_dim The dimension of each head
123+
* \param batch_size The batch size
124+
* \param layout The layout of last 3 dimensions in KV-Cache.
125+
* \param kv_data The flattened key-value cache
126+
* \param k_data The flattened key cache
127+
* \param v_data The flattened value cache
128+
* \param indices The page indices array
129+
* \param indptr The page indptr array
130+
* \param last_page_len The offset of the last page for each request in the batch
131+
* \param rope_pos_offset The start position of each request in the batch.
132+
* \note This constructor should only be used when page_storage == kIndices
133+
*/
134+
__host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
135+
uint32_t batch_size, QKVLayout layout, DType* kv_data,
136+
DType* k_data, DType* v_data, IdType* indices, IdType* indptr,
137+
IdType* last_page_len, IdType* rope_pos_offset = nullptr)
138+
: num_heads(num_heads),
139+
page_size(page_size),
140+
head_dim(head_dim),
141+
batch_size(batch_size),
142+
indices(indices),
143+
indptr(indptr),
144+
last_page_len(last_page_len),
145+
rope_pos_offset(rope_pos_offset) {
146+
bool kv_defined = kv_data != nullptr;
147+
if (kv_defined) {
148+
stride_page = 2 * num_heads * page_size * head_dim;
149+
this->k_data = kv_data;
150+
this->v_data = kv_data + num_heads * page_size * head_dim;
151+
} else {
152+
stride_page = num_heads * page_size * head_dim;
153+
this->k_data = k_data;
154+
this->v_data = v_data;
155+
}
156+
stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim;
157+
stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim;
158+
}
159+
118160
/*!
119161
* \brief Construct a paged key-value cache
120162
* \param num_heads The number of heads

python/csrc/batch_decode.cu

+61-19
Original file line numberDiff line numberDiff line change
@@ -105,40 +105,71 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
105105
}
106106

107107
std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
108-
torch::Tensor q, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
109-
torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
110-
unsigned int pos_encoding_mode, float logits_soft_cap, float sm_scale, float rope_scale,
111-
float rope_theta, bool return_lse) {
108+
torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache,
109+
std::optional<torch::Tensor> paged_k_cache, std::optional<torch::Tensor> paged_v_cache,
110+
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
111+
torch::Tensor paged_kv_last_page_len, unsigned int pos_encoding_mode, float logits_soft_cap,
112+
float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
112113
CHECK_INPUT(q);
113-
CHECK_INPUT(paged_kv_data);
114+
bool paged_kv_defined = paged_kv_cache.has_value();
115+
if (paged_kv_defined) {
116+
CHECK_INPUT(paged_kv_cache.value());
117+
} else {
118+
CHECK_INPUT(paged_k_cache.value());
119+
CHECK_INPUT(paged_v_cache.value());
120+
}
114121
CHECK_INPUT(paged_kv_indptr);
115122
CHECK_INPUT(paged_kv_indices);
116123
CHECK_INPUT(paged_kv_last_page_len);
117124
auto device = q.device();
118-
CHECK_EQ(paged_kv_data.device(), device);
125+
if (paged_kv_defined) {
126+
CHECK_EQ(paged_kv_cache->device(), device);
127+
} else {
128+
CHECK_EQ(paged_k_cache->device(), device);
129+
CHECK_EQ(paged_v_cache->device(), device);
130+
}
119131
CHECK_EQ(paged_kv_indices.device(), device);
120132
CHECK_EQ(paged_kv_indptr.device(), device);
121133
CHECK_EQ(paged_kv_last_page_len.device(), device);
122134
CHECK_DIM(3, q); // (B, H_qo, D)
123135
CHECK_DIM(1, paged_kv_last_page_len); // (B,)
124136
CHECK_DIM(1, paged_kv_indptr); // (B+1,)
125137
CHECK_DIM(1, paged_kv_indices); // (nnz,)
126-
// (num_max_pages, 2, H_kv, page_size, head_dim) for HND
127-
// (num_max_pages, 2, page_size, H_kv, head_dim) for NHD
128-
CHECK_DIM(5, paged_kv_data);
138+
if (paged_kv_defined) {
139+
// (num_max_pages, 2, H_kv, page_size, head_dim) for HND
140+
// (num_max_pages, 2, page_size, H_kv, head_dim) for NHD
141+
CHECK_DIM(5, paged_kv_cache.value());
142+
} else {
143+
// (num_max_pages, H_kv, page_size, head_dim) for HND
144+
// (num_max_pages, page_size, H_kv, head_dim) for NHD
145+
CHECK_DIM(4, paged_k_cache.value());
146+
CHECK_DIM(4, paged_v_cache.value());
147+
}
129148
int64_t batch_size = q.size(0);
130149
int64_t num_qo_heads = q.size(1);
131150
int64_t head_dim = q.size(2);
132151
int64_t num_kv_heads, page_size;
133-
if (kv_layout_ == QKVLayout::kHND) {
134-
num_kv_heads = paged_kv_data.size(2);
135-
page_size = paged_kv_data.size(3);
152+
if (paged_kv_defined) {
153+
CHECK_EQ(paged_kv_cache->size(1), 2);
154+
CHECK_EQ(paged_kv_cache->size(4), head_dim);
155+
if (kv_layout_ == QKVLayout::kHND) {
156+
num_kv_heads = paged_kv_cache->size(2);
157+
page_size = paged_kv_cache->size(3);
158+
} else {
159+
page_size = paged_kv_cache->size(2);
160+
num_kv_heads = paged_kv_cache->size(3);
161+
}
136162
} else {
137-
page_size = paged_kv_data.size(2);
138-
num_kv_heads = paged_kv_data.size(3);
163+
CHECK_EQ(paged_k_cache->size(3), head_dim);
164+
CHECK_EQ(paged_v_cache->size(3), head_dim);
165+
if (kv_layout_ == QKVLayout::kHND) {
166+
num_kv_heads = paged_k_cache->size(1);
167+
page_size = paged_k_cache->size(2);
168+
} else {
169+
page_size = paged_k_cache->size(1);
170+
num_kv_heads = paged_k_cache->size(2);
171+
}
139172
}
140-
CHECK_EQ(paged_kv_data.size(1), 2);
141-
CHECK_EQ(paged_kv_data.size(4), head_dim);
142173
CHECK_GE(paged_kv_indptr.size(0), batch_size + 1);
143174
CHECK_GE(paged_kv_last_page_len.size(0), batch_size);
144175
// TODO(Zihao): support dispatching to different data types
@@ -159,7 +190,8 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
159190
logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone;
160191

161192
auto q_scalar_type = q.scalar_type();
162-
auto kv_scalar_type = paged_kv_data.scalar_type();
193+
auto kv_scalar_type =
194+
paged_kv_defined ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type();
163195

164196
if (q_scalar_type == kv_scalar_type) {
165197
DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q_scalar_type, qkv_type, [&] {
@@ -169,7 +201,12 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
169201
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
170202
paged_kv_t<PageStorage::kIndices, qkv_type, int32_t> paged_kv(
171203
num_kv_heads, page_size, head_dim, batch_size, kv_layout_,
172-
static_cast<qkv_type*>(paged_kv_data.data_ptr()),
204+
static_cast<qkv_type*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
205+
: nullptr),
206+
static_cast<qkv_type*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr()
207+
: nullptr),
208+
static_cast<qkv_type*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr()
209+
: nullptr),
173210
static_cast<int32_t*>(paged_kv_indices.data_ptr()),
174211
static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
175212
static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));
@@ -197,7 +234,12 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
197234
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
198235
paged_kv_t<PageStorage::kIndices, kv_type, int32_t> paged_kv(
199236
num_kv_heads, page_size, head_dim, batch_size, kv_layout_,
200-
static_cast<kv_type*>(paged_kv_data.data_ptr()),
237+
static_cast<kv_type*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
238+
: nullptr),
239+
static_cast<kv_type*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr()
240+
: nullptr),
241+
static_cast<kv_type*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr()
242+
: nullptr),
201243
static_cast<int32_t*>(paged_kv_indices.data_ptr()),
202244
static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
203245
static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));

0 commit comments

Comments
 (0)