@@ -24,17 +24,18 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
     torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr,
     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
     unsigned int head_dim, unsigned int page_size, torch::Tensor empty_q_data) {
+  CHECK_INPUT(workspace_buffer);
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
-  CHECK_CONTIGUOUS(workspace_buffer);
+  CHECK_CONTIGUOUS(paged_kv_indptr);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   CHECK_DIM(1, qo_indptr);
   CHECK_DIM(1, workspace_buffer);
   qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32);
   paged_kv_indptr = paged_kv_indptr.to(torch::kCPU).to(torch::kInt32);
-
+  auto device = workspace_buffer.device();
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   handler_->SetCUDAStream(torch_current_stream);

   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
@@ -68,6 +69,12 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
   CHECK_INPUT(paged_kv_indptr);
   CHECK_INPUT(paged_kv_indices);
   CHECK_INPUT(paged_kv_last_page_len);
+  auto device = q.device();
+  CHECK_EQ(device, qo_indptr.device());
+  CHECK_EQ(device, paged_kv_data.device());
+  CHECK_EQ(device, paged_kv_indptr.device());
+  CHECK_EQ(device, paged_kv_indices.device());
+  CHECK_EQ(device, paged_kv_last_page_len.device());
   CHECK_DIM(3, q);          // (nnz_qo, H_qo, D)
   CHECK_DIM(1, qo_indptr);  // (B + 1,)
   // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND
@@ -100,7 +107,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
   paged_kv_indices = paged_kv_indices.to(torch::kInt32);
   paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32);

-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {
@@ -171,6 +178,14 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
   CHECK_INPUT(paged_kv_last_page_len);
   CHECK_INPUT(custom_mask);
   CHECK_INPUT(qk_indptr);
+  auto device = q.device();
+  CHECK_EQ(device, qo_indptr.device());
+  CHECK_EQ(device, paged_kv_data.device());
+  CHECK_EQ(device, paged_kv_indptr.device());
+  CHECK_EQ(device, paged_kv_indices.device());
+  CHECK_EQ(device, paged_kv_last_page_len.device());
+  CHECK_EQ(device, custom_mask.device());
+  CHECK_EQ(device, qk_indptr.device());
   CHECK_DIM(3, q);          // (nnz_qo, H_qo, D)
   CHECK_DIM(1, qo_indptr);  // (B + 1,)
   // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND
@@ -207,7 +222,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
   paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32);
   qk_indptr = qk_indptr.to(torch::kInt32);

-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {
@@ -267,17 +282,17 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
     torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor kv_indptr,
     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
     unsigned int head_dim, torch::Tensor empty_q_data) {
+  CHECK_INPUT(workspace_buffer);
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
-  CHECK_CONTIGUOUS(workspace_buffer);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   CHECK_DIM(1, qo_indptr);
   CHECK_DIM(1, workspace_buffer);
-
   qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32);
   kv_indptr = kv_indptr.to(torch::kCPU).to(torch::kInt32);
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  auto device = workspace_buffer.device();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   handler_->SetCUDAStream(torch_current_stream);

   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
@@ -309,6 +324,11 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
   CHECK_INPUT(k);
   CHECK_INPUT(v);
   CHECK_INPUT(kv_indptr);
+  auto device = q.device();
+  CHECK_EQ(device, qo_indptr.device());
+  CHECK_EQ(device, k.device());
+  CHECK_EQ(device, v.device());
+  CHECK_EQ(device, kv_indptr.device());
   CHECK_DIM(3, q);          // (nnz_qo, H_qo, D)
   CHECK_DIM(1, qo_indptr);  // (B + 1,)
   CHECK_DIM(3, k);          // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D)
@@ -330,7 +350,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
   qo_indptr = qo_indptr.to(torch::kInt32);
   kv_indptr = kv_indptr.to(torch::kInt32);

-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {
@@ -396,6 +416,13 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
   CHECK_INPUT(kv_indptr);
   CHECK_INPUT(custom_mask);
   CHECK_INPUT(qk_indptr);
+  auto device = q.device();
+  CHECK_EQ(device, qo_indptr.device());
+  CHECK_EQ(device, k.device());
+  CHECK_EQ(device, v.device());
+  CHECK_EQ(device, kv_indptr.device());
+  CHECK_EQ(device, custom_mask.device());
+  CHECK_EQ(device, qk_indptr.device());
   CHECK_DIM(3, q);          // (nnz_qo, H_qo, D)
   CHECK_DIM(1, qo_indptr);  // (B + 1,)
   CHECK_DIM(3, k);          // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D)
@@ -421,7 +448,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
   kv_indptr = kv_indptr.to(torch::kInt32);
   qk_indptr = qk_indptr.to(torch::kInt32);

-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {
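
Note (reviewer sketch, not part of the diff): every hunk above applies the same pattern — take `device` from one input tensor, `CHECK_EQ` the remaining tensors' devices against it, and pass `device.index()` to `c10::cuda::getCurrentCUDAStream`. This matters on multi-GPU hosts, where the parameterless overload returns the current device's stream even when the tensors live on another GPU. Below is a minimal standalone sketch of the idea using only the public PyTorch C++ API; the function name is illustrative and `TORCH_CHECK` stands in for this repo's `CHECK_*` macros.

// illustrative sketch only; mirrors the device-consistency checks added in this PR
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>

void launch_on_input_device(torch::Tensor q, torch::Tensor o) {
  TORCH_CHECK(q.is_cuda() && q.is_contiguous(), "q must be a contiguous CUDA tensor");
  TORCH_CHECK(q.device() == o.device(), "q and o must be on the same GPU");
  auto device = q.device();
  // Query the current stream of the tensors' device, not the caller's current
  // device; on multi-GPU systems these can differ.
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device.index());
  // ... enqueue kernels on `stream` ...
  (void)stream;
}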