Commit 1b84fab

bugfix: check gpu id in PyTorch APIs and use input tensor's gpu default stream (#361)
This PR fixes #349 by using the default stream of the input tensors' device instead of the default stream of the default device (which might differ from the input tensors' device). It also adds a sanity check on input tensor device IDs: all input tensors to a call must be on the same GPU.
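
For context, a minimal standalone sketch of the API distinction this PR relies on; the `demo` function, the guard, and the device indices are illustrative, not part of the diff:

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

void demo(torch::Tensor q) {
  c10::cuda::CUDAGuard guard(0);  // current device is now cuda:0
  // Stream of the *current* device (cuda:0), regardless of where q lives:
  cudaStream_t s_current_device = c10::cuda::getCurrentCUDAStream();
  // Stream of q's own device -- the form this PR switches every wrapper to:
  cudaStream_t s_input_device = c10::cuda::getCurrentCUDAStream(q.device().index());
  (void)s_current_device;
  (void)s_input_device;
}

Launching kernels for tensors on cuda:1 with cuda:0's stream is the class of bug reported in #349; deriving the stream from the inputs' device index removes it.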
Parent: 3536198 · Commit: 1b84fab

10 files changed: +129 −44 lines

Diff for: python/csrc/batch_decode.cu (+9 −3)
@@ -25,18 +25,19 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
     unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode,
     float logits_soft_cap, torch::Tensor empty_q_data, torch::Tensor empty_kv_data) {
+  CHECK_INPUT(workspace_buffer);
   // NOTE(zihao): not necessary to be CUDA tensor
   CHECK_CONTIGUOUS(indptr);
   CHECK_CONTIGUOUS(last_page_len);
-  CHECK_CONTIGUOUS(workspace_buffer);
   CHECK_DIM(1, indptr);
   CHECK_DIM(1, last_page_len);
   CHECK_DIM(1, workspace_buffer);
   CHECK_EQ(indptr.scalar_type(), torch::kInt32);
   CHECK_EQ(indptr.scalar_type(), torch::kInt32);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  auto device = workspace_buffer.device();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   handler_->SetCUDAStream(torch_current_stream);
   indptr = indptr.to(torch::kCPU);
   last_page_len = last_page_len.to(torch::kCPU);

@@ -116,6 +117,11 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
   CHECK_INPUT(paged_kv_indptr);
   CHECK_INPUT(paged_kv_indices);
   CHECK_INPUT(paged_kv_last_page_len);
+  auto device = q.device();
+  CHECK_EQ(paged_kv_data.device(), device);
+  CHECK_EQ(paged_kv_indices.device(), device);
+  CHECK_EQ(paged_kv_indptr.device(), device);
+  CHECK_EQ(paged_kv_last_page_len.device(), device);
   CHECK_DIM(3, q);                       // (B, H_qo, D)
   CHECK_DIM(1, paged_kv_last_page_len);  // (B,)
   CHECK_DIM(1, paged_kv_indptr);         // (B+1,)

@@ -144,7 +150,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
   CHECK_EQ(paged_kv_last_page_len.scalar_type(), torch::kInt32);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q);
   torch::Tensor lse;
   if (return_lse) {
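
The CHECK_* macros used throughout these files come from the repo's pytorch_extension_utils.h; a plausible sketch of what they expand to, assumed here for readability rather than copied from this commit:

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")

Under this reading, CHECK_INPUT(workspace_buffer) in BeginForward both requires the buffer to be a CUDA tensor (so its device index is meaningful) and keeps the old contiguity requirement, which is why the bare CHECK_CONTIGUOUS(workspace_buffer) line is dropped.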

Diff for: python/csrc/batch_prefill.cu (+37 −10)
@@ -24,17 +24,18 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
     torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr,
     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
     unsigned int head_dim, unsigned int page_size, torch::Tensor empty_q_data) {
+  CHECK_INPUT(workspace_buffer);
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
-  CHECK_CONTIGUOUS(workspace_buffer);
+  CHECK_CONTIGUOUS(paged_kv_indptr);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   CHECK_DIM(1, qo_indptr);
   CHECK_DIM(1, workspace_buffer);
   qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32);
   paged_kv_indptr = paged_kv_indptr.to(torch::kCPU).to(torch::kInt32);
-
+  auto device = workspace_buffer.device();
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   handler_->SetCUDAStream(torch_current_stream);
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {

@@ -68,6 +69,12 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
   CHECK_INPUT(paged_kv_indptr);
   CHECK_INPUT(paged_kv_indices);
   CHECK_INPUT(paged_kv_last_page_len);
+  auto device = q.device();
+  CHECK_EQ(device, qo_indptr.device());
+  CHECK_EQ(device, paged_kv_data.device());
+  CHECK_EQ(device, paged_kv_indptr.device());
+  CHECK_EQ(device, paged_kv_indices.device());
+  CHECK_EQ(device, paged_kv_last_page_len.device());
   CHECK_DIM(3, q);          // (nnz_qo, H_qo, D)
   CHECK_DIM(1, qo_indptr);  // (B + 1,)
   // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND

@@ -100,7 +107,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
   paged_kv_indices = paged_kv_indices.to(torch::kInt32);
   paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {

@@ -171,6 +178,14 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
   CHECK_INPUT(paged_kv_last_page_len);
   CHECK_INPUT(custom_mask);
   CHECK_INPUT(qk_indptr);
+  auto device = q.device();
+  CHECK_EQ(device, qo_indptr.device());
+  CHECK_EQ(device, paged_kv_data.device());
+  CHECK_EQ(device, paged_kv_indptr.device());
+  CHECK_EQ(device, paged_kv_indices.device());
+  CHECK_EQ(device, paged_kv_last_page_len.device());
+  CHECK_EQ(device, custom_mask.device());
+  CHECK_EQ(device, qk_indptr.device());
   CHECK_DIM(3, q);          // (nnz_qo, H_qo, D)
   CHECK_DIM(1, qo_indptr);  // (B + 1,)
   // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND

@@ -207,7 +222,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
   paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32);
   qk_indptr = qk_indptr.to(torch::kInt32);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {

@@ -267,17 +282,17 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
     torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor kv_indptr,
     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
     unsigned int head_dim, torch::Tensor empty_q_data) {
+  CHECK_INPUT(workspace_buffer);
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
-  CHECK_CONTIGUOUS(workspace_buffer);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   CHECK_DIM(1, qo_indptr);
   CHECK_DIM(1, workspace_buffer);
-
   qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32);
   kv_indptr = kv_indptr.to(torch::kCPU).to(torch::kInt32);
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  auto device = workspace_buffer.device();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   handler_->SetCUDAStream(torch_current_stream);
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {

@@ -309,6 +324,11 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
   CHECK_INPUT(k);
   CHECK_INPUT(v);
   CHECK_INPUT(kv_indptr);
+  auto device = q.device();
+  CHECK_EQ(device, qo_indptr.device());
+  CHECK_EQ(device, k.device());
+  CHECK_EQ(device, v.device());
+  CHECK_EQ(device, kv_indptr.device());
   CHECK_DIM(3, q);          // (nnz_qo, H_qo, D)
   CHECK_DIM(1, qo_indptr);  // (B + 1,)
   CHECK_DIM(3, k);          // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D)

@@ -330,7 +350,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
   qo_indptr = qo_indptr.to(torch::kInt32);
   kv_indptr = kv_indptr.to(torch::kInt32);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {

@@ -396,6 +416,13 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
   CHECK_INPUT(kv_indptr);
   CHECK_INPUT(custom_mask);
   CHECK_INPUT(qk_indptr);
+  auto device = q.device();
+  CHECK_EQ(device, qo_indptr.device());
+  CHECK_EQ(device, k.device());
+  CHECK_EQ(device, v.device());
+  CHECK_EQ(device, kv_indptr.device());
+  CHECK_EQ(device, custom_mask.device());
+  CHECK_EQ(device, qk_indptr.device());
   CHECK_DIM(3, q);          // (nnz_qo, H_qo, D)
   CHECK_DIM(1, qo_indptr);  // (B + 1,)
   CHECK_DIM(3, k);          // (nnz_kv, H_kv, D) if NHD else (H_kv, nnz_kv, D)

@@ -421,7 +448,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
   kv_indptr = kv_indptr.to(torch::kInt32);
   qk_indptr = qk_indptr.to(torch::kInt32);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q, q.options());
   torch::Tensor lse = torch::empty({0});
   if (return_lse) {
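
Note that BeginForward derives its stream from workspace_buffer.device() while Forward derives it from q.device(), so the two must refer to the same GPU. A caller-side sketch of that implicit contract; the shapes, sizes, and device index are made up for illustration:

#include <torch/extension.h>

void caller_contract_sketch() {
  // Both the workspace buffer and the inputs live on cuda:1, so BeginForward
  // (which reads workspace_buffer.device()) and Forward (which reads q.device())
  // resolve the same per-device stream.
  auto opts_u8 = torch::dtype(torch::kUInt8).device(torch::kCUDA, 1);
  auto opts_f16 = torch::dtype(torch::kHalf).device(torch::kCUDA, 1);
  torch::Tensor workspace_buffer = torch::empty({16 * 1024 * 1024}, opts_u8);
  torch::Tensor q = torch::randn({128, 32, 128}, opts_f16);
  // wrapper.BeginForward(workspace_buffer, ...);  // illustrative call sites
  // wrapper.Forward(q, ...);
}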

Diff for: python/csrc/cascade.cu (+13 −3)
@@ -26,6 +26,10 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
   CHECK_INPUT(s_a);
   CHECK_INPUT(v_b);
   CHECK_INPUT(s_b);
+  auto device = v_a.device();
+  CHECK_EQ(s_a.device(), device);
+  CHECK_EQ(v_b.device(), device);
+  CHECK_EQ(s_b.device(), device);
   CHECK_DIM(3, v_a);
   CHECK_DIM(2, s_a);
   CHECK_DIM(3, v_b);

@@ -39,7 +43,7 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
   unsigned int seq_len = v_a.size(0);
   unsigned int num_heads = v_a.size(1);
   unsigned int head_dim = v_a.size(2);
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto v_merged = torch::empty_like(v_a, v_a.options());
   auto s_merged = torch::empty({seq_len, num_heads}, s_a.options());
 

@@ -64,6 +68,10 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
   CHECK_INPUT(s);
   CHECK_INPUT(v_other);
   CHECK_INPUT(s_other);
+  auto device = v.device();
+  CHECK_EQ(s.device(), device);
+  CHECK_EQ(v_other.device(), device);
+  CHECK_EQ(s_other.device(), device);
   CHECK_DIM(3, v);
   CHECK_DIM(2, s);
   CHECK_DIM(3, v_other);

@@ -77,7 +85,7 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
   unsigned int seq_len = v.size(0);
   unsigned int num_heads = v.size(1);
   unsigned int head_dim = v.size(2);
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
 
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v.scalar_type(), c_type, [&] {
     cudaError_t status = MergeStateInPlace(

@@ -95,6 +103,8 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
 std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s) {
   CHECK_INPUT(v);
   CHECK_INPUT(s);
+  auto device = v.device();
+  CHECK_EQ(s.device(), device);
   CHECK_DIM(4, v);
   CHECK_DIM(3, s);
   CHECK_EQ(v.size(0), s.size(0));

@@ -105,7 +115,7 @@ std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s) {
   unsigned int num_heads = v.size(2);
   unsigned int head_dim = v.size(3);
   s = s.to(torch::kFloat32);
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto v_merged = torch::empty({seq_len, num_heads, head_dim}, v.options());
   auto s_merged = torch::empty({seq_len, num_heads}, s.options());
 
Diff for: python/csrc/group_gemm.cu (+9 −5)
@@ -31,9 +31,12 @@ torch::Tensor CutlassSegmentGEMMPyTorchWrapper::Forward(torch::Tensor seg_indptr
     unsigned int batch_size,
     bool weight_column_major) {
   // TODO(Zihao): Add more checks here
-  CHECK_CUDA(seg_indptr);
-  CHECK_CUDA(x);
-  CHECK_CUDA(weight);
+  CHECK_INPUT(seg_indptr);
+  CHECK_INPUT(x);
+  CHECK_INPUT(weight);
+  auto device = x.device();
+  CHECK_EQ(seg_indptr.device(), device);
+  CHECK_EQ(weight.device(), device);
   CHECK_DIM(2, x);       // x: [sum(m_i), d_in]
   CHECK_DIM(3, weight);  // weight: [num_weights, d_out, d_in] if weight_column_major, [num_weights,
                          // d_in, d_out] otherwise

@@ -42,12 +45,13 @@ torch::Tensor CutlassSegmentGEMMPyTorchWrapper::Forward(torch::Tensor seg_indptr
   int64_t d_in = weight_column_major ? weight.size(2) : weight.size(1);
   CHECK_EQ(x.size(1), d_in);
   auto y = torch::zeros({cumulative_batch_size, d_out}, x.options());
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   seg_indptr = seg_indptr.to(torch::kInt64);
 
   bool weight_indices_defined = weight_indices.numel() > 0;
   if (weight_indices_defined) {
-    CHECK_CUDA(weight_indices);
+    CHECK_INPUT(weight_indices);
+    CHECK_EQ(weight_indices.device(), device);
     weight_indices = weight_indices.to(torch::kInt64);
   }
 
Diff for: python/csrc/norm.cu (+3 −1)
@@ -23,13 +23,15 @@ using namespace flashinfer;
 torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps) {
   CHECK_INPUT(x);
   CHECK_INPUT(w);
+  auto device = x.device();
+  CHECK_EQ(w.device(), device);
   CHECK_DIM(2, x);  // x: (batch_size, hidden_size)
   CHECK_DIM(1, w);  // w: (hidden_size)
   CHECK_EQ(x.size(1), w.size(0));
   unsigned int batch_size = x.size(0);
   unsigned int hidden_size = x.size(1);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto y = torch::empty_like(x);
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(x.scalar_type(), c_type, [&] {
     cudaError_t status = norm::RMSNorm(

Diff for: python/csrc/page.cu (+8 −1)
@@ -45,6 +45,13 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
   CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32);
   CHECK_EQ(kv_indices.scalar_type(), torch::kInt32);
   CHECK_EQ(kv_last_page_len.scalar_type(), torch::kInt32);
+  auto device = append_indptr.device();
+  CHECK_EQ(append_key.device(), device);
+  CHECK_EQ(append_value.device(), device);
+  CHECK_EQ(kv_data.device(), device);
+  CHECK_EQ(kv_indices.device(), device);
+  CHECK_EQ(kv_indptr.device(), device);
+  CHECK_EQ(kv_last_page_len.device(), device);
 
   constexpr PageStorage page_storage = PageStorage::kIndices;
   QKVLayout kv_layout = QKVLayout(layout);

@@ -65,7 +72,7 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
   CHECK_EQ(append_value.size(1), num_heads);
   CHECK_EQ(append_key.size(2), head_dim);
 
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
 
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_data.scalar_type(), c_type, [&] {
     DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {

Diff for: python/csrc/quantization.cu (+6 −2)
@@ -22,9 +22,10 @@ using namespace flashinfer;
 
 torch::Tensor packbits(torch::Tensor x, const std::string& bitorder) {
   CHECK_INPUT(x);
+  auto device = x.device();
   TORCH_CHECK(bitorder == "big" || bitorder == "little", "bitorder must be 'big' or 'little'");
   x = x.to(torch::kBool);
-  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
 
   int64_t num_elements = x.numel();
   int64_t num_output_elements = (num_elements + 7) / 8;

@@ -46,6 +47,9 @@ torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
   CHECK_INPUT(x);
   CHECK_INPUT(input_indptr);
   CHECK_INPUT(output_indptr);
+  auto device = x.device();
+  CHECK_EQ(input_indptr.device(), device);
+  CHECK_EQ(output_indptr.device(), device);
   TORCH_CHECK(bitorder == "big" || bitorder == "little", "bitorder must be 'big' or 'little'");
   unsigned int batch_size = input_indptr.size(0) - 1;
   CHECK_EQ(output_indptr.size(0), batch_size + 1);

@@ -59,6 +63,6 @@ torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
       static_cast<int32_t*>(input_indptr.data_ptr()),
       static_cast<int32_t*>(output_indptr.data_ptr()), batch_size,
       bitorder == "big" ? quantization::BitOrder::kBig : quantization::BitOrder::kLittle,
-      c10::cuda::getCurrentCUDAStream());
+      c10::cuda::getCurrentCUDAStream(device.index()));
   return y;
 }
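
The `auto device = x.device(); CHECK_EQ(y.device(), device);` pattern recurs in every wrapper above. A hypothetical helper, not part of this PR or the repo and shown only to name the invariant, could consolidate it with a clearer error message than a bare CHECK_EQ:

// Hypothetical macro: assert `other` lives on the same device as `ref`.
#define CHECK_SAME_DEVICE(ref, other)                                    \
  TORCH_CHECK((other).device() == (ref).device(), #other " must be on " \
              "the same device as " #ref " (all inputs on one GPU)")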
