@@ -21,7 +21,7 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
}
}  // namespace

-template <typename scalar_t>
+template <typename scalar_t, typename token_cnts_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                            int32_t* sorted_token_ids,
                                            int32_t* expert_ids,
@@ -32,12 +32,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  extern __shared__ int32_t shared_mem[];
-
-  int32_t* tokens_cnts =
-      shared_mem;  // 2d tensor with shape (blockDim.x + 1, num_experts)
-  int32_t* cumsum =
-      shared_mem +
-      (blockDim.x + 1) * num_experts;  // 1d tensor with shape (num_experts + 1)
+  int32_t* cumsum = shared_mem;  // 1d tensor with shape (num_experts + 1)
+  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
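Note on the reworked layout: cumsum now sits at the front of the dynamic shared memory and stays int32_t, while the (blockDim.x + 1) x num_experts counter table follows it and is reinterpreted as token_cnts_t (int32_t or uint16_t, chosen on the host side further down). A minimal device-side sketch of that partitioning, using a hypothetical helper name that is not part of this PR:

#include <cstdint>

// Sketch only: split dynamic shared memory the way the kernel above does.
template <typename token_cnts_t>
__device__ void partition_shared_mem(int32_t* shared_mem, int32_t*& cumsum,
                                     token_cnts_t*& tokens_cnts) {
  // 1d prefix-sum array with (num_experts + 1) entries, kept as int32_t.
  cumsum = shared_mem;
  // The counter table starts (blockDim.x + 1) int32_t words in; since the
  // launch uses num_thread = max(num_experts, WARP_SIZE) threads,
  // blockDim.x + 1 >= num_experts + 1, so cumsum fits in front of it.
  tokens_cnts = reinterpret_cast<token_cnts_t*>(shared_mem + blockDim.x + 1);
}

Keeping cumsum in int32_t matters because the block-size padding can push the running total past 65535 even when topk_ids.numel() <= 65535.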
@@ -74,7 +70,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                  block_size) *
              block_size;
    }
-    *total_tokens_post_pad = cumsum[num_experts];
+    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();
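The context lines in this hunk are the tail of the per-expert prefix sum: each expert's token count is rounded up to a multiple of block_size before being accumulated into cumsum, and total_tokens_post_pad is the final entry. A toy host-side illustration of that rounding (the counts and the ceildiv helper are made up for the example, not taken from the kernel):

#include <cstdint>
#include <cstdio>

constexpr int32_t ceildiv(int32_t a, int32_t b) { return (a + b - 1) / b; }

int main() {
  const int32_t block_size = 16;
  const int32_t counts[3] = {5, 20, 0};  // hypothetical per-expert counts
  int32_t cumsum = 0;
  for (int32_t c : counts) {
    cumsum += ceildiv(c, block_size) * block_size;  // 16, 32, 0
  }
  std::printf("total_tokens_post_pad = %d\n", cumsum);  // 48
  return 0;
}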
@@ -224,26 +220,44 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  // If we have very large number of experts, we can no longer use shared
-  // memory.
-  // TODO(simon): the right solution should be calculating the exact right
-  // amount of shared memory and use that. The num_experts >= 256 is just a
-  // temporary solution to unblock Deepseek V3.
-  if (num_experts >= 256) {
+  int device_max_shared_mem;
+  auto dev = topk_ids.get_device();
+  cudaDeviceGetAttribute(&device_max_shared_mem,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+
+  const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+  const int32_t shared_mem_i32 =
+      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
+  const int32_t shared_mem_i16 =
+      ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
+      (num_experts + 1) * sizeof(int32_t);
+
+  bool use_global_memory = false;
+  bool use_i16 = false;  // Use uint16_t for shared memory token counts
+  if (shared_mem_i16 > device_max_shared_mem) {
+    use_global_memory = true;
+  } else if (shared_mem_i32 > device_max_shared_mem &&
+             topk_ids.numel() <= 65535) {
+    // When topk_ids has at most 65535 elements (the max value of uint16),
+    // every element of token_cnts is also at most 65535, so uint16 can be
+    // used as the dtype of token_cnts.
+    use_i16 = true;
+  }
+
+  if (use_global_memory) {
    VLLM_DISPATCH_INTEGRAL_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
          // tensors
          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);

-          const int32_t mem_tokens_cnts =
-              ((num_experts + 1) * num_experts) * sizeof(int32_t);
-          const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
-          // allocate global memory
-          int32_t* tokens_cnts;
-          int32_t* cumsum;
-          cudaMalloc(&tokens_cnts, mem_tokens_cnts);
-          cudaMalloc(&cumsum, mem_cumsum);
+          auto options_int = torch::TensorOptions()
+                                 .dtype(torch::kInt)
+                                 .device(topk_ids.device());
+          torch::Tensor token_cnts_buffer =
+              torch::empty({(num_experts + 1) * num_experts}, options_int);
+          torch::Tensor cumsum_buffer =
+              torch::empty({num_experts + 1}, options_int);

          auto kernel =
              vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
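For a sense of the numbers behind this dispatch, here is a small standalone sketch of the two shared-memory footprints the host code compares, assuming WARP_SIZE == 32 and num_experts = 256 (the DeepSeek-V3 case that the removed num_experts >= 256 special case targeted); the real code queries the device limit at runtime rather than hard-coding anything:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t num_experts = 256;  // e.g. DeepSeek-V3
  const int32_t num_thread = std::max<int32_t>(num_experts, 32);  // >= WARP_SIZE
  // Same formulas as in the diff above.
  const int32_t shared_mem_i32 =
      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
  const int32_t shared_mem_i16 =
      ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
      (num_experts + 1) * sizeof(int32_t);
  // Prints "i32: 264196 bytes, i16: 132612 bytes": the uint16_t counters
  // roughly halve the footprint, which is what lets 256 experts stay in
  // shared memory on recent data-center GPUs.
  std::printf("i32: %d bytes, i16: %d bytes\n", shared_mem_i32, shared_mem_i16);
  return 0;
}

If even the uint16_t layout exceeds the device's opt-in limit, the code falls back to the global-memory kernel with torch-allocated buffers, as the surrounding hunks show.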
@@ -252,25 +266,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-              topk_ids.numel(), tokens_cnts, cumsum);
-          cudaFree(tokens_cnts);
-          cudaFree(cumsum);
+              topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
+              cumsum_buffer.data_ptr<int32_t>());
        });
-  } else {
+  } else if (use_i16) {
    VLLM_DISPATCH_INTEGRAL_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-          // tensors
-          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
-          const int32_t shared_mem =
-              ((num_thread + 1) * num_experts + (num_experts + 1)) *
-              sizeof(int32_t);
-
          // set dynamic shared mem
-          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
+          auto kernel =
+              vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
+          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
+              (void*)kernel, shared_mem_i16));
+          kernel<<<1, num_thread, shared_mem_i16, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel());
+        });
+  } else {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+          auto kernel =
+              vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
-              (void*)kernel, shared_mem));
-          kernel<<<1, num_thread, shared_mem, stream>>>(
+              (void*)kernel, shared_mem_i32));
+          kernel<<<1, num_thread, shared_mem_i32, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),