
Commit ccb3e86

jinzhen-lin authored and Isotr0py committed
[Kernel] fix moe_align_block_size error condition (vllm-project#12239)
Signed-off-by: Jinzhen Lin <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
1 parent a0246ff commit ccb3e86

File tree

1 file changed: +6 -4 lines changed


csrc/moe/moe_align_sum_kernels.cu

Lines changed: 6 additions & 4 deletions
@@ -233,15 +233,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
       (num_experts + 1) * sizeof(int32_t);
 
   bool use_global_memory = false;
-  bool use_i16 = false;  // Use uint16_t for shared memory token counts
-  if (shared_mem_i16 > device_max_shared_mem) {
-    use_global_memory = true;
-  } else if (shared_mem_i32 > device_max_shared_mem &&
+  bool use_i16 = false;  // Use uint16_t for shared memory token counts
+  if (shared_mem_i32 < device_max_shared_mem) {
+    // Do nothing in this case. We're all set to use int32_t token counts
+  } else if (shared_mem_i16 < device_max_shared_mem &&
              topk_ids.numel() <= 65535) {
     // when nelements of topk_ids is smaller than 65535 (max value of uint16),
     // element value of token_cnts would also smaller than 65535,
     // so we can use uint16 as dtype of token_cnts
     use_i16 = true;
+  } else {
+    use_global_memory = true;
   }
 
   if (use_global_memory) {
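For illustration, here is a minimal standalone sketch of the corrected selection logic. The enum, the choose_storage helper, and all the numbers below are hypothetical stand-ins for the shared-memory sizes computed earlier in moe_align_block_size; they are not part of the commit.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Hypothetical labels for the three execution paths the launcher can take.
enum class TokenCntsStorage { SharedI32, SharedI16, GlobalMemory };

// Sketch of the fixed decision: prefer int32_t counts in shared memory,
// fall back to uint16_t counts when they fit and the element count stays
// below 65535, and otherwise (the case the old condition appears to have
// missed) fall back to global memory.
TokenCntsStorage choose_storage(size_t shared_mem_i32, size_t shared_mem_i16,
                                size_t device_max_shared_mem,
                                int64_t num_topk_ids) {
  if (shared_mem_i32 < device_max_shared_mem) {
    return TokenCntsStorage::SharedI32;
  } else if (shared_mem_i16 < device_max_shared_mem &&
             num_topk_ids <= 65535) {
    return TokenCntsStorage::SharedI16;
  } else {
    return TokenCntsStorage::GlobalMemory;
  }
}

int main() {
  // Made-up numbers: i32 counts would need 200 KiB, i16 counts 100 KiB,
  // the device offers 160 KiB of shared memory, and there are 70000 topk ids.
  // i16 would fit, but the element count exceeds 65535, so the fixed logic
  // picks global memory instead of overflowing shared memory.
  TokenCntsStorage s = choose_storage(200 * 1024, 100 * 1024, 160 * 1024, 70000);
  std::printf("use_global_memory = %d\n", s == TokenCntsStorage::GlobalMemory);
  return 0;
}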
