
Commit e4564cb

jinzhen-lin authored and tjtanaa committed
[Kernel] fix moe_align_block_size error condition (vllm-project#12239)
Signed-off-by: Jinzhen Lin <[email protected]>
1 parent 98b8414 commit e4564cb

File tree

csrc/moe/moe_align_sum_kernels.cu
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/siglip.py

3 files changed: +13 -11 lines

csrc/moe/moe_align_sum_kernels.cu

Lines changed: 9 additions & 5 deletions
@@ -33,7 +33,9 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
 
   extern __shared__ int32_t shared_mem[];
   int32_t* cumsum = shared_mem;  // 1d tensor with shape (num_experts + 1)
-  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);
+  token_cnts_t* tokens_cnts =
+      (token_cnts_t*)(shared_mem + num_experts +
+                      1);  // 2d tensor with shape (blockDim.x + 1, num_experts)
 
   for (int i = 0; i < num_experts; ++i) {
     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
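Context on why the offset matters (not part of the commit): the kernel lays dynamic shared memory out as cumsum, an int32_t array of length num_experts + 1, immediately followed by the tokens_cnts table of shape (blockDim.x + 1, num_experts), so the counts pointer has to skip num_experts + 1 ints rather than blockDim.x + 1. Below is a minimal sketch of the size calculation that layout implies; the helper name and the num_thread parameter are illustrative, not taken from the file.

#include <cstddef>
#include <cstdint>

// Illustrative only: bytes of dynamic shared memory implied by the layout above,
// i.e. cumsum[(num_experts + 1)] int32_t values followed by
// tokens_cnts[(num_thread + 1) * num_experts] entries of token_cnts_t.
template <typename token_cnts_t>
std::size_t moe_align_shared_mem_bytes(int num_experts, int num_thread) {
  std::size_t cumsum_bytes = (num_experts + 1) * sizeof(int32_t);
  std::size_t cnts_bytes =
      static_cast<std::size_t>(num_thread + 1) * num_experts * sizeof(token_cnts_t);
  return cumsum_bytes + cnts_bytes;
}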
@@ -234,14 +236,16 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
 
   bool use_global_memory = false;
   bool use_i16 = false;  // Use uint16_t for shared memory token counts
-  if (shared_mem_i16 > device_max_shared_mem) {
-    use_global_memory = true;
-  } else if (shared_mem_i32 > device_max_shared_mem &&
+  if (shared_mem_i32 < device_max_shared_mem) {
+    // Do nothing in this case. We're all set to use int32_t token counts
+  } else if (shared_mem_i16 < device_max_shared_mem &&
              topk_ids.numel() <= 65535) {
     // when nelements of topk_ids is smaller than 65535 (max value of uint16),
     // element value of token_cnts would also smaller than 65535,
     // so we can use uint16 as dtype of token_cnts
     use_i16 = true;
+  } else {
+    use_global_memory = true;
   }
 
   if (use_global_memory) {
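Read as a whole, the rewritten branch is a three-way choice: prefer int32 counts in shared memory when they fit, fall back to uint16 counts when that smaller footprint fits and every count is provably below 65536 (counts are bounded by topk_ids.numel()), and otherwise spill the counts to global memory. A hedged restatement of that decision order follows; the enum and helper are illustrative and not part of the source.

#include <cstddef>
#include <cstdint>

// Illustrative restatement of the corrected selection order above.
enum class MoeCountsStorage { kSharedI32, kSharedI16, kGlobal };

inline MoeCountsStorage choose_counts_storage(std::size_t shared_mem_i32,
                                              std::size_t shared_mem_i16,
                                              std::int64_t topk_numel,
                                              std::size_t device_max_shared_mem) {
  if (shared_mem_i32 < device_max_shared_mem) {
    return MoeCountsStorage::kSharedI32;  // int32 counts fit in shared memory
  } else if (shared_mem_i16 < device_max_shared_mem && topk_numel <= 65535) {
    // every count is at most topk_ids.numel() <= 65535, so uint16 cannot overflow
    return MoeCountsStorage::kSharedI16;
  }
  return MoeCountsStorage::kGlobal;  // neither fits: keep the counts in global memory
}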
@@ -342,4 +346,4 @@ void moe_sum(torch::Tensor& input,  // [num_tokens, topk, hidden_size]
       at::sum_out(output, input, 1);
       break;
   }
-}
+}

vllm/model_executor/models/paligemma.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
             "up_proj",
         ],
     }
-
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config

vllm/model_executor/models/siglip.py

Lines changed: 3 additions & 5 deletions
@@ -348,12 +348,10 @@ def __init__(
         if quant_config and quant_config.get_name() == "bitsandbytes":
             quantizable = True
         else:
-            # For other quantization, we require the hidden size to be a
+            # For other quantization, we require the hidden size to be a
             # multiple of 64
-            quantizable = (
-                config.hidden_size % 64 == 0
-                and config.intermediate_size % 64 == 0
-            )
+            quantizable = (config.hidden_size % 64 == 0
+                           and config.intermediate_size % 64 == 0)
         self.fc1 = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
