@@ -71,14 +71,14 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
71
71
return plan_info.ToVector ();
72
72
}
73
73
74
- std::vector< torch::Tensor> BatchPrefillWithRaggedKVCacheRun (
74
+ torch::Tensor BatchPrefillWithRaggedKVCacheRun (
75
75
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
76
76
torch::Tensor int_workspace_buffer, std::vector<int64_t > plan_info_vec, torch::Tensor q,
77
77
torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_custom_mask,
78
78
std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
79
79
torch::Tensor kv_indptr, std::optional<torch::Tensor> maybe_qk_indptr, unsigned int layout,
80
80
int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
81
- bool return_lse ) {
81
+ std::optional<torch::Tensor> maybe_lse ) {
82
82
PrefillPlanInfo plan_info;
83
83
plan_info.FromVector (plan_info_vec);
84
84
QKVLayout kv_layout = static_cast <QKVLayout>(layout);
@@ -98,10 +98,11 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
98
98
auto device = float_workspace_buffer.device ();
99
99
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
100
100
auto o = torch::empty_like (q, q.options ());
101
- int64_t nnz_qo = q.size (0 );
102
- torch::Tensor lse = torch::empty ({0 });
103
- if (return_lse) {
104
- lse = torch::empty ({nnz_qo, num_qo_heads}, q.options ().dtype (torch::kFloat32 ));
101
+ if (maybe_lse) {
102
+ const auto & lse = *maybe_lse;
103
+ TORCH_CHECK (lse.size (0 ) == q.size (0 ), lse.size (0 ), q.size (0 ));
104
+ TORCH_CHECK (lse.size (1 ) == q.size (1 ), lse.size (1 ), q.size (1 ));
105
+ TORCH_CHECK (lse.dtype () == torch::kFloat32 , " lse must be float32" );
105
106
}
106
107
107
108
void * float_buffer_ptr = float_workspace_buffer.data_ptr ();
@@ -140,7 +141,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
140
141
: nullptr ,
141
142
/* q_offset=*/ nullptr ,
142
143
/* k_rope_pos_offset=*/ nullptr , static_cast <DTypeO*>(o.data_ptr ()),
143
- /* lse=*/ return_lse ? static_cast <float *>(lse. data_ptr ()) : nullptr ,
144
+ /* lse=*/ (maybe_lse ? static_cast <float *>(maybe_lse-> data_ptr ()) : nullptr ) ,
144
145
/* alibi_slopes=*/ nullptr , num_qo_heads, num_kv_heads, q_stride_n, q_stride_h,
145
146
kv_stride_n, kv_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale,
146
147
rope_theta);
@@ -187,22 +188,18 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
187
188
});
188
189
});
189
190
190
- if (return_lse) {
191
- return {o, lse};
192
- } else {
193
- return {o};
194
- }
191
+ return o;
195
192
}
196
193
197
- std::vector< torch::Tensor> BatchPrefillWithPagedKVCacheRun (
194
+ torch::Tensor BatchPrefillWithPagedKVCacheRun (
198
195
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
199
196
torch::Tensor int_workspace_buffer, std::vector<int64_t > plan_info_vec, torch::Tensor q,
200
197
torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
201
198
std::optional<torch::Tensor> maybe_custom_mask, std::optional<torch::Tensor> maybe_alibi_slopes,
202
199
torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
203
200
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
204
201
unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
205
- float rope_scale, float rope_theta, bool return_lse ) {
202
+ float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse ) {
206
203
PrefillPlanInfo plan_info;
207
204
plan_info.FromVector (plan_info_vec);
208
205
QKVLayout kv_layout = static_cast <QKVLayout>(layout);
@@ -221,10 +218,11 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
221
218
222
219
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
223
220
auto o = torch::empty_like (q, q.options ());
224
- int64_t nnz_qo = q.size (0 );
225
- torch::Tensor lse = torch::empty ({0 });
226
- if (return_lse) {
227
- lse = torch::empty ({nnz_qo, num_qo_heads}, q.options ().dtype (torch::kFloat32 ));
221
+ if (maybe_lse) {
222
+ const auto & lse = *maybe_lse;
223
+ TORCH_CHECK (lse.size (0 ) == q.size (0 ), lse.size (0 ), q.size (0 ));
224
+ TORCH_CHECK (lse.size (1 ) == q.size (1 ), lse.size (1 ), q.size (1 ));
225
+ TORCH_CHECK (lse.dtype () == torch::kFloat32 , " lse must be float32" );
228
226
}
229
227
230
228
void * float_buffer_ptr = static_cast <void *>(float_workspace_buffer.data_ptr ());
@@ -277,7 +275,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
277
275
maybe_qk_indptr.has_value () ? static_cast <IdType*>(maybe_qk_indptr->data_ptr ())
278
276
: nullptr ,
279
277
/* q_offset=*/ nullptr , static_cast <DTypeO*>(o.data_ptr ()),
280
- /* lse=*/ return_lse ? static_cast <float *>(lse. data_ptr ()) : nullptr ,
278
+ /* lse=*/ (maybe_lse ? static_cast <float *>(maybe_lse-> data_ptr ()) : nullptr ) ,
281
279
/* alibi_slopes=*/ nullptr , num_qo_heads, q_stride_n, q_stride_h, window_left,
282
280
logits_soft_cap, sm_scale, rope_scale, rope_theta);
283
281
@@ -323,9 +321,5 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
323
321
});
324
322
});
325
323
326
- if (return_lse) {
327
- return {o, lse};
328
- } else {
329
- return {o};
330
- }
324
+ return o;
331
325
}
0 commit comments