
Commit 0b907c8

jinzhen-lin, mgoin, and tlrmchlsmth authored and committed

[Kernel] optimize moe_align_block_size for cuda graph and large num_experts (e.g. DeepSeek-V3) (vllm-project#12222)

Signed-off-by: Jinzhen Lin <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>

1 parent 377b72c   commit 0b907c8

File tree: 2 files changed, +58 -37 lines changed

  csrc/moe/moe_align_sum_kernels.cu
  vllm/config.py

csrc/moe/moe_align_sum_kernels.cu

Lines changed: 57 additions & 36 deletions
@@ -21,7 +21,7 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
 }
 }  // namespace
 
-template <typename scalar_t>
+template <typename scalar_t, typename token_cnts_t>
 __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                             int32_t* sorted_token_ids,
                                             int32_t* expert_ids,
@@ -32,12 +32,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   const size_t start_idx = threadIdx.x * tokens_per_thread;
 
   extern __shared__ int32_t shared_mem[];
-
-  int32_t* tokens_cnts =
-      shared_mem;  // 2d tensor with shape (blockDim.x + 1, num_experts)
-  int32_t* cumsum =
-      shared_mem +
-      (blockDim.x + 1) * num_experts;  // 1d tensor with shape (num_experts + 1)
+  int32_t* cumsum = shared_mem;  // 1d tensor with shape (num_experts + 1)
+  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);
 
   for (int i = 0; i < num_experts; ++i) {
     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
@@ -74,7 +70,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                           block_size) *
                       block_size;
     }
-    *total_tokens_post_pad = cumsum[num_experts];
+    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
   }
 
   __syncthreads();
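
Note on the kernel-side change: shared memory now starts with the int32_t cumsum array, and the per-thread token counters take whatever element type the kernel is instantiated with. As a rough illustration of why that matters for large expert counts, the sketch below mirrors the host-side size formula; the helper name is invented for this note and is not part of the commit.

// Illustrative sketch only: approximate dynamic shared-memory footprint of
// the templated kernel, following the same formula the host code uses.
#include <cstddef>
#include <cstdint>

template <typename token_cnts_t>
constexpr size_t moe_align_smem_bytes(int num_experts, int num_thread) {
  // cumsum: (num_experts + 1) int32_t entries at the front of shared memory.
  const size_t cumsum_bytes = (num_experts + 1) * sizeof(int32_t);
  // tokens_cnts: a (num_thread + 1) x num_experts grid of counters stored
  // after cumsum, with the element width given by token_cnts_t.
  const size_t cnts_bytes =
      static_cast<size_t>(num_thread + 1) * num_experts * sizeof(token_cnts_t);
  return cumsum_bytes + cnts_bytes;
}

// For 256 experts and 256 threads (the DeepSeek-V3 case):
//   moe_align_smem_bytes<int32_t>(256, 256)  -> roughly 258 KiB
//   moe_align_smem_bytes<uint16_t>(256, 256) -> roughly 130 KiB
// so the uint16_t variant can fit under typical opt-in shared-memory limits
// where the int32_t layout cannot.

The uint16_t instantiation is only safe when no counter can exceed 65535, which is what the topk_ids.numel() <= 65535 guard in the host code below ensures.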
@@ -224,26 +220,44 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           torch::Tensor num_tokens_post_pad) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  // If we have very large number of experts, we can no longer use shared
-  // memory.
-  // TODO(simon): the right solution should be calculating the exact right
-  // amount of shared memory and use that. The num_experts >= 256 is just a
-  // temporary solution to unblock Deepseek V3.
-  if (num_experts >= 256) {
+  int device_max_shared_mem;
+  auto dev = topk_ids.get_device();
+  cudaDeviceGetAttribute(&device_max_shared_mem,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+
+  const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+  const int32_t shared_mem_i32 =
+      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
+  const int32_t shared_mem_i16 =
+      ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
+      (num_experts + 1) * sizeof(int32_t);
+
+  bool use_global_memory = false;
+  bool use_i16 = false;  // Use uint16_t for shared memory token counts
+  if (shared_mem_i16 > device_max_shared_mem) {
+    use_global_memory = true;
+  } else if (shared_mem_i32 > device_max_shared_mem &&
+             topk_ids.numel() <= 65535) {
+    // when nelements of topk_ids is smaller than 65535 (max value of uint16),
+    // element value of token_cnts would also smaller than 65535,
+    // so we can use uint16 as dtype of token_cnts
+    use_i16 = true;
+  }
+
+  if (use_global_memory) {
     VLLM_DISPATCH_INTEGRAL_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
           // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
           // tensors
           const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
 
-          const int32_t mem_tokens_cnts =
-              ((num_experts + 1) * num_experts) * sizeof(int32_t);
-          const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
-          // allocate global memory
-          int32_t* tokens_cnts;
-          int32_t* cumsum;
-          cudaMalloc(&tokens_cnts, mem_tokens_cnts);
-          cudaMalloc(&cumsum, mem_cumsum);
+          auto options_int = torch::TensorOptions()
+                                 .dtype(torch::kInt)
+                                 .device(topk_ids.device());
+          torch::Tensor token_cnts_buffer =
+              torch::empty({(num_experts + 1) * num_experts}, options_int);
+          torch::Tensor cumsum_buffer =
+              torch::empty({num_experts + 1}, options_int);
 
           auto kernel =
               vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
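
The host-side logic above replaces the old hard-coded num_experts >= 256 check with a measured decision: query the device's opt-in shared-memory limit, fall back to global memory when even the uint16_t layout does not fit, and otherwise use uint16_t counters only when the int32_t layout is too large and the counters cannot overflow. A compact restatement of that policy follows; the enum and function names are invented for illustration, while vLLM itself expresses the same logic with the use_global_memory / use_i16 flags shown above.

// Illustrative restatement of the selection logic; SmemStrategy and
// choose_strategy are invented names for this sketch, not vLLM symbols.
#include <cuda_runtime.h>
#include <cstdint>

enum class SmemStrategy { kInt32Counters, kUint16Counters, kGlobalMemory };

inline SmemStrategy choose_strategy(int num_experts, int num_thread,
                                    int64_t topk_numel, int device) {
  int max_smem = 0;
  cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin,
                         device);

  const int64_t smem_i32 =
      (static_cast<int64_t>(num_thread + 1) * num_experts + num_experts + 1) *
      sizeof(int32_t);
  const int64_t smem_i16 =
      static_cast<int64_t>(num_thread + 1) * num_experts * sizeof(uint16_t) +
      (num_experts + 1) * sizeof(int32_t);

  if (smem_i16 > max_smem) {
    // Not even the compact layout fits: stage counters in global memory.
    return SmemStrategy::kGlobalMemory;
  }
  if (smem_i32 > max_smem && topk_numel <= 65535) {
    // Counters are bounded by topk_ids.numel(), so uint16_t cannot overflow.
    return SmemStrategy::kUint16Counters;
  }
  return SmemStrategy::kInt32Counters;
}

Allocating the fallback buffers with torch::empty instead of cudaMalloc/cudaFree is what makes this path CUDA-graph friendly: raw cudaMalloc is a synchronizing call that cannot be issued while a graph is being captured, whereas tensors drawn from PyTorch's caching allocator can be, and it also avoids a costly allocation on every invocation.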
@@ -252,25 +266,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
               sorted_token_ids.data_ptr<int32_t>(),
               experts_ids.data_ptr<int32_t>(),
               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-              topk_ids.numel(), tokens_cnts, cumsum);
-          cudaFree(tokens_cnts);
-          cudaFree(cumsum);
+              topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
+              cumsum_buffer.data_ptr<int32_t>());
         });
-  } else {
+  } else if (use_i16) {
     VLLM_DISPATCH_INTEGRAL_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-          // tensors
-          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
-          const int32_t shared_mem =
-              ((num_thread + 1) * num_experts + (num_experts + 1)) *
-              sizeof(int32_t);
-
           // set dynamic shared mem
-          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
+          auto kernel =
+              vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
+          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
+              (void*)kernel, shared_mem_i16));
+          kernel<<<1, num_thread, shared_mem_i16, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel());
+        });
+  } else {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+          auto kernel =
+              vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
           AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
-              (void*)kernel, shared_mem));
-          kernel<<<1, num_thread, shared_mem, stream>>>(
+              (void*)kernel, shared_mem_i32));
+          kernel<<<1, num_thread, shared_mem_i32, stream>>>(
               topk_ids.data_ptr<scalar_t>(),
               sorted_token_ids.data_ptr<int32_t>(),
               experts_ids.data_ptr<int32_t>(),
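
Both shared-memory paths raise the kernel's dynamic shared-memory cap through VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize before launching, since a block may use at most 48 KB of dynamic shared memory unless that limit is raised explicitly. A minimal sketch of the underlying CUDA pattern, using a hypothetical kernel that is not from vLLM:

#include <cuda_runtime.h>

// Hypothetical kernel, used only to show the opt-in launch pattern.
__global__ void large_smem_kernel(int* out) {
  extern __shared__ int smem[];
  smem[threadIdx.x] = static_cast<int>(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0) out[0] = smem[blockDim.x - 1];
}

// Kernels may use at most 48 KB of dynamic shared memory by default;
// asking for more requires raising the per-function attribute first.
void launch_with_large_smem(int* out, int num_threads, size_t smem_bytes,
                            cudaStream_t stream) {
  cudaFuncSetAttribute((void*)large_smem_kernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       static_cast<int>(smem_bytes));
  large_smem_kernel<<<1, num_threads, smem_bytes, stream>>>(out);
}

With the attribute raised, the uint16_t path can request well over 48 KB, up to whatever limit cudaDevAttrMaxSharedMemoryPerBlockOptin reports for the device.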

vllm/config.py

Lines changed: 1 addition & 1 deletion
@@ -607,7 +607,7 @@ def _verify_cuda_graph(self) -> None:
         self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                           self.max_model_len)
 
-        MODEL_NOT_SUPPORT_CUDA_GRAPH = ['deepseek_v3', 'mllama']
+        MODEL_NOT_SUPPORT_CUDA_GRAPH = ['mllama']
         if (self.hf_config.model_type in MODEL_NOT_SUPPORT_CUDA_GRAPH
                 and not self.enforce_eager):
             logger.warning(
