@@ -45,15 +45,16 @@ at::Tensor BatchPrefillWithKVCachePlan(
45
45
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
46
46
at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
47
47
int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
48
- int64_t head_dim_vo, bool causal, int64_t cuda_stream ) {
48
+ int64_t head_dim_vo, bool causal) {
49
49
size_t float_workspace_size_in_bytes =
50
50
float_workspace_buffer.size (0 ) * float_workspace_buffer.element_size ();
51
51
size_t int_workspace_size_in_bytes =
52
52
int_workspace_buffer.size (0 ) * int_workspace_buffer.element_size ();
53
53
54
54
PrefillPlanInfo plan_info;
55
55
56
- cudaStream_t stream = reinterpret_cast <cudaStream_t>(cuda_stream);
56
+ const c10::cuda::OptionalCUDAGuard device_guard (float_workspace_buffer.device ());
57
+ const cudaStream_t stream = c10::cuda::getCurrentCUDAStream ();
57
58
cudaError_t status = PrefillPlan<IdType>(
58
59
float_workspace_buffer.data_ptr (), float_workspace_size_in_bytes,
59
60
int_workspace_buffer.data_ptr (), page_locked_int_workspace_buffer.data_ptr (),
@@ -72,8 +73,7 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
72
73
at::Tensor q, at::Tensor k, at::Tensor v,
73
74
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
74
75
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
75
- int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS,
76
- int64_t cuda_stream) {
76
+ int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS) {
77
77
PrefillPlanInfo plan_info;
78
78
plan_info.FromVector (tensor_to_vec (plan_info_vec));
79
79
QKVLayout kv_layout = static_cast <QKVLayout>(layout);
@@ -109,7 +109,8 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
109
109
auto q_scalar_type = q.scalar_type ();
110
110
auto kv_scalar_type = k.scalar_type ();
111
111
112
- cudaStream_t stream = reinterpret_cast <cudaStream_t>(cuda_stream);
112
+ const c10::cuda::OptionalCUDAGuard device_guard (float_workspace_buffer.device ());
113
+ const cudaStream_t stream = c10::cuda::getCurrentCUDAStream ();
113
114
114
115
DISPATCH_context (
115
116
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
@@ -193,12 +194,14 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
193
194
});
194
195
}
195
196
196
- void BatchPrefillWithPagedKVCacheRun (
197
- at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
198
- at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
199
- at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
200
- at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
201
- int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
197
+ void BatchPrefillWithPagedKVCacheRun (at::Tensor float_workspace_buffer,
198
+ at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
199
+ at::Tensor q, at::Tensor paged_k_cache,
200
+ at::Tensor paged_v_cache, at::Tensor qo_indptr,
201
+ at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices,
202
+ at::Tensor paged_kv_last_page_len, at::Tensor o,
203
+ std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
204
+ int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS) {
202
205
PrefillPlanInfo plan_info;
203
206
plan_info.FromVector (tensor_to_vec (plan_info_vec));
204
207
QKVLayout kv_layout = static_cast <QKVLayout>(layout);
@@ -239,7 +242,8 @@ void BatchPrefillWithPagedKVCacheRun(
239
242
TORCH_CHECK (k_strides == v_strides, " k/v strides must be identical" );
240
243
kv_cache_strides = k_strides.data ();
241
244
242
- cudaStream_t stream = reinterpret_cast <cudaStream_t>(cuda_stream);
245
+ const c10::cuda::OptionalCUDAGuard device_guard (float_workspace_buffer.device ());
246
+ const cudaStream_t stream = c10::cuda::getCurrentCUDAStream ();
243
247
244
248
DISPATCH_context (
245
249
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
0 commit comments