Skip to content

Commit 61e049a

Browse files
authored
perf: Fix python API overhead when CUDAGraph is not enabled (#969)
This PR fixes issue #960. We identified several performance bottlenecks in our Python APIs when kernels are not captured by CUDAGraph: 1. The device guard in Python is slow (`with input.device as device:`). 2. Getting the current CUDA stream in Python is time-consuming. These issues were introduced in the JIT refactor after v0.1.6 (mainly to accelerate JIT compilation speed). In this PR, we changed back to getting the stream and applying the device guard in C++. @MichoChan @xiaoqi35
1 parent f65b93f commit 61e049a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+1218
-1278
lines changed

csrc/activation.cu

+10-7
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ __device__ __forceinline__ float gelu_tanh(const float& val) {
3232
return val * cdf;
3333
}
3434

35-
void silu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl, int64_t cuda_stream) {
35+
void silu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl) {
3636
int d = input.size(-1) / 2;
3737
int64_t num_tokens = input.numel() / input.size(-1);
3838

39-
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
39+
const c10::cuda::OptionalCUDAGuard device_guard(out.device());
40+
auto stream = at::cuda::getCurrentCUDAStream();
4041

4142
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
4243
uint32_t vec_size = 16 / sizeof(c_type);
@@ -63,11 +64,13 @@ void silu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl, int64_t c
6364
});
6465
}
6566

66-
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl, int64_t cuda_stream) {
67+
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl) {
6768
int d = input.size(-1) / 2;
6869
int64_t num_tokens = input.numel() / input.size(-1);
6970

70-
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
71+
const c10::cuda::OptionalCUDAGuard device_guard(out.device());
72+
auto stream = at::cuda::getCurrentCUDAStream();
73+
7174
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
7275
uint32_t vec_size = 16 / sizeof(c_type);
7376
cudaLaunchConfig_t config;
@@ -93,12 +96,12 @@ void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl, int6
9396
});
9497
}
9598

96-
void gelu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl, int64_t cuda_stream) {
99+
void gelu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl) {
97100
int d = input.size(-1) / 2;
98101
int64_t num_tokens = input.numel() / input.size(-1);
99-
dim3 grid(num_tokens);
102+
const c10::cuda::OptionalCUDAGuard device_guard(out.device());
103+
auto stream = at::cuda::getCurrentCUDAStream();
100104

101-
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
102105
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
103106
uint32_t vec_size = 16 / sizeof(c_type);
104107
cudaLaunchConfig_t config;

csrc/batch_decode.cu

+13-9
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ at::Tensor BatchDecodeWithPagedKVCachePlan(
3838
at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
3939
int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph,
4040
int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo,
41-
at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream) {
41+
at::Tensor empty_q_data, at::Tensor empty_kv_data) {
4242
size_t float_workspace_size_in_bytes =
4343
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
4444
size_t int_workspace_size_in_bytes =
@@ -53,7 +53,8 @@ at::Tensor BatchDecodeWithPagedKVCachePlan(
5353
"CUDA cores template only supports equal head dim for QK and VO, please use tensor "
5454
"cores template for different head dim");
5555

56-
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
56+
const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
57+
const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
5758
DISPATCH_context(
5859
DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
5960
USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] {
@@ -77,12 +78,14 @@ at::Tensor BatchDecodeWithPagedKVCachePlan(
7778
return vec_to_tensor(plan_info.ToVector());
7879
}
7980

80-
void BatchDecodeWithPagedKVCacheRun(
81-
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
82-
at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor paged_kv_indptr,
83-
at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o,
84-
std::optional<at::Tensor> maybe_lse, int64_t kv_layout_code,
85-
int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
81+
void BatchDecodeWithPagedKVCacheRun(at::Tensor float_workspace_buffer,
82+
at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
83+
at::Tensor q, at::Tensor paged_k_cache,
84+
at::Tensor paged_v_cache, at::Tensor paged_kv_indptr,
85+
at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
86+
at::Tensor o, std::optional<at::Tensor> maybe_lse,
87+
int64_t kv_layout_code,
88+
int64_t window_left ADDITIONAL_FUNC_PARAMS) {
8689
DecodePlanInfo plan_info;
8790
plan_info.FromVector(tensor_to_vec(plan_info_vec));
8891
QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
@@ -129,7 +132,8 @@ void BatchDecodeWithPagedKVCacheRun(
129132
TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
130133
kv_cache_strides = k_strides.data();
131134

132-
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
135+
const c10::cuda::OptionalCUDAGuard device_guard(device);
136+
const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
133137

134138
DISPATCH_context(
135139
DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,

csrc/batch_decode_jit_pybind.cu

+9-7
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@ at::Tensor BatchDecodeWithPagedKVCachePlan(
2121
at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
2222
int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph,
2323
int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo,
24-
at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream);
24+
at::Tensor empty_q_data, at::Tensor empty_kv_data);
2525

26-
void BatchDecodeWithPagedKVCacheRun(
27-
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
28-
at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor paged_kv_indptr,
29-
at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o,
30-
std::optional<at::Tensor> maybe_lse, int64_t kv_layout_code,
31-
int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);
26+
void BatchDecodeWithPagedKVCacheRun(at::Tensor float_workspace_buffer,
27+
at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
28+
at::Tensor q, at::Tensor paged_k_cache,
29+
at::Tensor paged_v_cache, at::Tensor paged_kv_indptr,
30+
at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
31+
at::Tensor o, std::optional<at::Tensor> maybe_lse,
32+
int64_t kv_layout_code,
33+
int64_t window_left ADDITIONAL_FUNC_PARAMS);
3234

3335
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
3436
// Batched decode with paged KV-Cache plan

csrc/batch_mla_plan.cu

+4-3
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer,
2626
at::Tensor int_workspace_buffer,
2727
at::Tensor page_locked_int_workspace_buffer,
2828
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len,
29-
int64_t num_heads, int64_t head_dim_o, bool causal,
30-
int64_t cuda_stream) {
29+
int64_t num_heads, int64_t head_dim_o, bool causal) {
3130
size_t float_workspace_size_in_bytes =
3231
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
3332
size_t int_workspace_size_in_bytes =
@@ -37,7 +36,9 @@ at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer,
3736

3837
int batch_size = kv_len.size(0);
3938

40-
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
39+
const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
40+
const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
41+
4142
cudaError_t status =
4243
MLAPlan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
4344
int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),

csrc/batch_mla_pybind.cu

+2-3
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,14 @@ at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer,
2020
at::Tensor int_workspace_buffer,
2121
at::Tensor page_locked_int_workspace_buffer,
2222
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len,
23-
int64_t num_heads, int64_t head_dim_o, bool causal,
24-
int64_t cuda_stream);
23+
int64_t num_heads, int64_t head_dim_o, bool causal);
2524

2625
void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
2726
at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
2827
at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices,
2928
at::Tensor o, std::optional<at::Tensor> maybe_lse,
3029
int64_t mask_mode_code, int64_t num_heads, int64_t page_size,
31-
double sm_scale, int64_t cuda_stream);
30+
double sm_scale);
3231

3332
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
3433
m.def("plan", &BatchMLAPagedAttentionPlan);

csrc/batch_mla_run.cu

+3-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int
2929
at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices,
3030
at::Tensor o, std::optional<at::Tensor> maybe_lse,
3131
int64_t mask_mode_code, int64_t num_heads, int64_t page_size,
32-
double sm_scale, int64_t cuda_stream) {
32+
double sm_scale) {
3333
// q_nope: [n, num_heads, head_dim_ckv]
3434
// q_pe: [n, num_heads, head_dim_kpe]
3535
// ckv_cache: [num_pages, page_size, head_dim_ckv]
@@ -58,7 +58,8 @@ void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int
5858
unsigned int o_stride_n = o.stride(0);
5959
unsigned int o_stride_h = o.stride(1);
6060

61-
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
61+
const c10::cuda::OptionalCUDAGuard device_guard(device);
62+
const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
6263

6364
DISPATCH_context(
6465
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, [&] {

csrc/batch_mla_sm90_plan.cu

+4-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ at::Tensor BatchMLAPagedAttentionSM90Plan(at::Tensor float_workspace_buffer,
2727
at::Tensor page_locked_int_workspace_buffer,
2828
at::Tensor qo_indptr, at::Tensor kv_indptr,
2929
at::Tensor kv_len, int64_t num_heads, int64_t head_dim_o,
30-
bool causal, int64_t cuda_stream) {
30+
bool causal) {
3131
size_t float_workspace_size_in_bytes =
3232
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
3333
size_t int_workspace_size_in_bytes =
@@ -37,7 +37,9 @@ at::Tensor BatchMLAPagedAttentionSM90Plan(at::Tensor float_workspace_buffer,
3737

3838
int batch_size = kv_len.size(0);
3939

40-
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
40+
const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
41+
const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
42+
4143
cudaError_t status =
4244
MLAPlan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
4345
int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),

csrc/batch_mla_sm90_pybind.cu

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ at::Tensor BatchMLAPagedAttentionSM90Plan(at::Tensor float_workspace_buffer,
2121
at::Tensor page_locked_int_workspace_buffer,
2222
at::Tensor qo_indptr, at::Tensor kv_indptr,
2323
at::Tensor kv_len, int64_t num_heads, int64_t head_dim_o,
24-
bool causal, int64_t cuda_stream);
24+
bool causal);
2525

2626
void BatchMLAPagedAttentionSM90Run(at::Tensor float_workspace_buffer,
2727
at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
2828
at::Tensor q_nope, at::Tensor q_pe, at::Tensor ckv_cache,
2929
at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor o,
3030
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
3131
int64_t num_heads, int64_t page_size,
32-
double sm_scale ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);
32+
double sm_scale ADDITIONAL_FUNC_PARAMS);
3333

3434
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
3535
m.def("plan", &BatchMLAPagedAttentionSM90Plan);

csrc/batch_mla_sm90_run.cu

+3-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ void BatchMLAPagedAttentionSM90Run(at::Tensor float_workspace_buffer,
3030
at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor o,
3131
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
3232
int64_t num_heads, int64_t page_size,
33-
double sm_scale ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
33+
double sm_scale ADDITIONAL_FUNC_PARAMS) {
3434
// q_nope: [n, num_heads, head_dim_ckv]
3535
// q_pe: [n, num_heads, head_dim_kpe]
3636
// ckv_cache: [num_pages, page_size, head_dim_ckv]
@@ -59,7 +59,8 @@ void BatchMLAPagedAttentionSM90Run(at::Tensor float_workspace_buffer,
5959
unsigned int o_stride_n = o.stride(0);
6060
unsigned int o_stride_h = o.stride(1);
6161

62-
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
62+
const c10::cuda::OptionalCUDAGuard device_guard(device);
63+
const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
6364

6465
DISPATCH_context(
6566
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, [&] {

csrc/batch_prefill.cu

+16-12
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,16 @@ at::Tensor BatchPrefillWithKVCachePlan(
4545
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
4646
at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
4747
int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
48-
int64_t head_dim_vo, bool causal, int64_t cuda_stream) {
48+
int64_t head_dim_vo, bool causal) {
4949
size_t float_workspace_size_in_bytes =
5050
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
5151
size_t int_workspace_size_in_bytes =
5252
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
5353

5454
PrefillPlanInfo plan_info;
5555

56-
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
56+
const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
57+
const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
5758
cudaError_t status = PrefillPlan<IdType>(
5859
float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
5960
int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
@@ -72,8 +73,7 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
7273
at::Tensor q, at::Tensor k, at::Tensor v,
7374
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
7475
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
75-
int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS,
76-
int64_t cuda_stream) {
76+
int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS) {
7777
PrefillPlanInfo plan_info;
7878
plan_info.FromVector(tensor_to_vec(plan_info_vec));
7979
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
@@ -109,7 +109,8 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
109109
auto q_scalar_type = q.scalar_type();
110110
auto kv_scalar_type = k.scalar_type();
111111

112-
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
112+
const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
113+
const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
113114

114115
DISPATCH_context(
115116
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,
@@ -193,12 +194,14 @@ void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
193194
});
194195
}
195196

196-
void BatchPrefillWithPagedKVCacheRun(
197-
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
198-
at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
199-
at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
200-
at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
201-
int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
197+
void BatchPrefillWithPagedKVCacheRun(at::Tensor float_workspace_buffer,
198+
at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
199+
at::Tensor q, at::Tensor paged_k_cache,
200+
at::Tensor paged_v_cache, at::Tensor qo_indptr,
201+
at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices,
202+
at::Tensor paged_kv_last_page_len, at::Tensor o,
203+
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
204+
int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS) {
202205
PrefillPlanInfo plan_info;
203206
plan_info.FromVector(tensor_to_vec(plan_info_vec));
204207
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
@@ -239,7 +242,8 @@ void BatchPrefillWithPagedKVCacheRun(
239242
TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
240243
kv_cache_strides = k_strides.data();
241244

242-
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
245+
const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
246+
const cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
243247

244248
DISPATCH_context(
245249
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,

0 commit comments

Comments
 (0)