Skip to content

Commit ea86f81

Browse files
authored
bugfix: fix batch_prefill.cu in AOT mode after #554 (#559)
#554 didn't update `batch_prefill.cu` (which is used in AOT mode) according to the API change. This PR fixes the issue. cc @abcdabcd987
1 parent 6227562 commit ea86f81

File tree

1 file changed

+18
-24
lines changed

1 file changed

+18
-24
lines changed

flashinfer-aot/csrc_aot/batch_prefill.cu

+18-24
Original file line number | Diff line number | Diff line change
@@ -71,14 +71,14 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
7171
return plan_info.ToVector();
7272
}
7373

74-
std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
74+
torch::Tensor BatchPrefillWithRaggedKVCacheRun(
7575
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
7676
torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
7777
torch::Tensor k, torch::Tensor v, std::optional<torch::Tensor> maybe_custom_mask,
7878
std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
7979
torch::Tensor kv_indptr, std::optional<torch::Tensor> maybe_qk_indptr, unsigned int layout,
8080
int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
81-
bool return_lse) {
81+
std::optional<torch::Tensor> maybe_lse) {
8282
PrefillPlanInfo plan_info;
8383
plan_info.FromVector(plan_info_vec);
8484
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
@@ -98,10 +98,11 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
9898
auto device = float_workspace_buffer.device();
9999
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
100100
auto o = torch::empty_like(q, q.options());
101-
int64_t nnz_qo = q.size(0);
102-
torch::Tensor lse = torch::empty({0});
103-
if (return_lse) {
104-
lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
101+
if (maybe_lse) {
102+
const auto& lse = *maybe_lse;
103+
TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0));
104+
TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1));
105+
TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
105106
}
106107

107108
void* float_buffer_ptr = float_workspace_buffer.data_ptr();
@@ -140,7 +141,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
140141
: nullptr,
141142
/*q_offset=*/nullptr,
142143
/*k_rope_pos_offset=*/nullptr, static_cast<DTypeO*>(o.data_ptr()),
143-
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
144+
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
144145
/*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h,
145146
kv_stride_n, kv_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale,
146147
rope_theta);
@@ -187,22 +188,18 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
187188
});
188189
});
189190

190-
if (return_lse) {
191-
return {o, lse};
192-
} else {
193-
return {o};
194-
}
191+
return o;
195192
}
196193

197-
std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
194+
torch::Tensor BatchPrefillWithPagedKVCacheRun(
198195
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
199196
torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
200197
torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
201198
std::optional<torch::Tensor> maybe_custom_mask, std::optional<torch::Tensor> maybe_alibi_slopes,
202199
torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
203200
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
204201
unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
205-
float rope_scale, float rope_theta, bool return_lse) {
202+
float rope_scale, float rope_theta, std::optional<torch::Tensor> maybe_lse) {
206203
PrefillPlanInfo plan_info;
207204
plan_info.FromVector(plan_info_vec);
208205
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
@@ -221,10 +218,11 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
221218

222219
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
223220
auto o = torch::empty_like(q, q.options());
224-
int64_t nnz_qo = q.size(0);
225-
torch::Tensor lse = torch::empty({0});
226-
if (return_lse) {
227-
lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
221+
if (maybe_lse) {
222+
const auto& lse = *maybe_lse;
223+
TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0));
224+
TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1));
225+
TORCH_CHECK(lse.dtype() == torch::kFloat32, "lse must be float32");
228226
}
229227

230228
void* float_buffer_ptr = static_cast<void*>(float_workspace_buffer.data_ptr());
@@ -277,7 +275,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
277275
maybe_qk_indptr.has_value() ? static_cast<IdType*>(maybe_qk_indptr->data_ptr())
278276
: nullptr,
279277
/*q_offset=*/nullptr, static_cast<DTypeO*>(o.data_ptr()),
280-
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
278+
/*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
281279
/*alibi_slopes=*/nullptr, num_qo_heads, q_stride_n, q_stride_h, window_left,
282280
logits_soft_cap, sm_scale, rope_scale, rope_theta);
283281

@@ -323,9 +321,5 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
323321
});
324322
});
325323

326-
if (return_lse) {
327-
return {o, lse};
328-
} else {
329-
return {o};
330-
}
324+
return o;
331325
}

0 commit comments

Comments
 (0)