@@ -105,40 +105,71 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
105
105
}
106
106
107
107
std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward (
108
- torch::Tensor q, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
109
- torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
110
- unsigned int pos_encoding_mode, float logits_soft_cap, float sm_scale, float rope_scale,
111
- float rope_theta, bool return_lse) {
108
+ torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache,
109
+ std::optional<torch::Tensor> paged_k_cache, std::optional<torch::Tensor> paged_v_cache,
110
+ torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
111
+ torch::Tensor paged_kv_last_page_len, unsigned int pos_encoding_mode, float logits_soft_cap,
112
+ float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
112
113
CHECK_INPUT (q);
113
- CHECK_INPUT (paged_kv_data);
114
+ bool paged_kv_defined = paged_kv_cache.has_value ();
115
+ if (paged_kv_defined) {
116
+ CHECK_INPUT (paged_kv_cache.value ());
117
+ } else {
118
+ CHECK_INPUT (paged_k_cache.value ());
119
+ CHECK_INPUT (paged_v_cache.value ());
120
+ }
114
121
CHECK_INPUT (paged_kv_indptr);
115
122
CHECK_INPUT (paged_kv_indices);
116
123
CHECK_INPUT (paged_kv_last_page_len);
117
124
auto device = q.device ();
118
- CHECK_EQ (paged_kv_data.device (), device);
125
+ if (paged_kv_defined) {
126
+ CHECK_EQ (paged_kv_cache->device (), device);
127
+ } else {
128
+ CHECK_EQ (paged_k_cache->device (), device);
129
+ CHECK_EQ (paged_v_cache->device (), device);
130
+ }
119
131
CHECK_EQ (paged_kv_indices.device (), device);
120
132
CHECK_EQ (paged_kv_indptr.device (), device);
121
133
CHECK_EQ (paged_kv_last_page_len.device (), device);
122
134
CHECK_DIM (3 , q); // (B, H_qo, D)
123
135
CHECK_DIM (1 , paged_kv_last_page_len); // (B,)
124
136
CHECK_DIM (1 , paged_kv_indptr); // (B+1,)
125
137
CHECK_DIM (1 , paged_kv_indices); // (nnz,)
126
- // (num_max_pages, 2, H_kv, page_size, head_dim) for HND
127
- // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD
128
- CHECK_DIM (5 , paged_kv_data);
138
+ if (paged_kv_defined) {
139
+ // (num_max_pages, 2, H_kv, page_size, head_dim) for HND
140
+ // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD
141
+ CHECK_DIM (5 , paged_kv_cache.value ());
142
+ } else {
143
+ // (num_max_pages, H_kv, page_size, head_dim) for HND
144
+ // (num_max_pages, page_size, H_kv, head_dim) for NHD
145
+ CHECK_DIM (4 , paged_k_cache.value ());
146
+ CHECK_DIM (4 , paged_v_cache.value ());
147
+ }
129
148
int64_t batch_size = q.size (0 );
130
149
int64_t num_qo_heads = q.size (1 );
131
150
int64_t head_dim = q.size (2 );
132
151
int64_t num_kv_heads, page_size;
133
- if (kv_layout_ == QKVLayout::kHND ) {
134
- num_kv_heads = paged_kv_data.size (2 );
135
- page_size = paged_kv_data.size (3 );
152
+ if (paged_kv_defined) {
153
+ CHECK_EQ (paged_kv_cache->size (1 ), 2 );
154
+ CHECK_EQ (paged_kv_cache->size (4 ), head_dim);
155
+ if (kv_layout_ == QKVLayout::kHND ) {
156
+ num_kv_heads = paged_kv_cache->size (2 );
157
+ page_size = paged_kv_cache->size (3 );
158
+ } else {
159
+ page_size = paged_kv_cache->size (2 );
160
+ num_kv_heads = paged_kv_cache->size (3 );
161
+ }
136
162
} else {
137
- page_size = paged_kv_data.size (2 );
138
- num_kv_heads = paged_kv_data.size (3 );
163
+ CHECK_EQ (paged_k_cache->size (3 ), head_dim);
164
+ CHECK_EQ (paged_v_cache->size (3 ), head_dim);
165
+ if (kv_layout_ == QKVLayout::kHND ) {
166
+ num_kv_heads = paged_k_cache->size (1 );
167
+ page_size = paged_k_cache->size (2 );
168
+ } else {
169
+ page_size = paged_k_cache->size (1 );
170
+ num_kv_heads = paged_k_cache->size (2 );
171
+ }
139
172
}
140
- CHECK_EQ (paged_kv_data.size (1 ), 2 );
141
- CHECK_EQ (paged_kv_data.size (4 ), head_dim);
142
173
CHECK_GE (paged_kv_indptr.size (0 ), batch_size + 1 );
143
174
CHECK_GE (paged_kv_last_page_len.size (0 ), batch_size);
144
175
// TODO(Zihao): support dispatching to different data types
@@ -159,7 +190,8 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
159
190
logits_soft_cap > 0 .f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone ;
160
191
161
192
auto q_scalar_type = q.scalar_type ();
162
- auto kv_scalar_type = paged_kv_data.scalar_type ();
193
+ auto kv_scalar_type =
194
+ paged_kv_defined ? paged_kv_cache->scalar_type () : paged_k_cache->scalar_type ();
163
195
164
196
if (q_scalar_type == kv_scalar_type) {
165
197
DISPATCH_PYTORCH_DTYPE_TO_CTYPE (q_scalar_type, qkv_type, [&] {
@@ -169,7 +201,12 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
169
201
PosEncodingMode (pos_encoding_mode), POS_ENCODING_MODE, [&] {
170
202
paged_kv_t <PageStorage::kIndices , qkv_type, int32_t > paged_kv (
171
203
num_kv_heads, page_size, head_dim, batch_size, kv_layout_,
172
- static_cast <qkv_type*>(paged_kv_data.data_ptr ()),
204
+ static_cast <qkv_type*>(paged_kv_cache.has_value () ? paged_kv_cache->data_ptr ()
205
+ : nullptr ),
206
+ static_cast <qkv_type*>(paged_k_cache.has_value () ? paged_k_cache->data_ptr ()
207
+ : nullptr ),
208
+ static_cast <qkv_type*>(paged_v_cache.has_value () ? paged_v_cache->data_ptr ()
209
+ : nullptr ),
173
210
static_cast <int32_t *>(paged_kv_indices.data_ptr ()),
174
211
static_cast <int32_t *>(paged_kv_indptr.data_ptr ()),
175
212
static_cast <int32_t *>(paged_kv_last_page_len.data_ptr ()));
@@ -197,7 +234,12 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
197
234
PosEncodingMode (pos_encoding_mode), POS_ENCODING_MODE, [&] {
198
235
paged_kv_t <PageStorage::kIndices , kv_type, int32_t > paged_kv (
199
236
num_kv_heads, page_size, head_dim, batch_size, kv_layout_,
200
- static_cast <kv_type*>(paged_kv_data.data_ptr ()),
237
+ static_cast <kv_type*>(paged_kv_cache.has_value () ? paged_kv_cache->data_ptr ()
238
+ : nullptr ),
239
+ static_cast <kv_type*>(paged_k_cache.has_value () ? paged_k_cache->data_ptr ()
240
+ : nullptr ),
241
+ static_cast <kv_type*>(paged_v_cache.has_value () ? paged_v_cache->data_ptr ()
242
+ : nullptr ),
201
243
static_cast <int32_t *>(paged_kv_indices.data_ptr ()),
202
244
static_cast <int32_t *>(paged_kv_indptr.data_ptr ()),
203
245
static_cast <int32_t *>(paged_kv_last_page_len.data_ptr ()));
0 commit comments