Skip to content

Commit 4c6ca52

Browse files
committed
Fix moe align block issue for mixtral
Signed-off-by: ElizaWszola <[email protected]>
1 parent ab5bbf5 commit 4c6ca52

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

csrc/moe/moe_align_sum_kernels.cu

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,9 +31,17 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
3131
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
3232
const size_t start_idx = threadIdx.x * tokens_per_thread;
3333

34+
// compute aligned shared mem offset to make sure cumsum is aligned
35+
int cnts_byte_offset =
36+
((blockDim.x + 1) * num_experts) * sizeof(token_cnts_t);
37+
int aligned_offset =
38+
(cnts_byte_offset + sizeof(int32_t) - 1) / sizeof(int32_t);
39+
3440
extern __shared__ int32_t shared_mem[];
35-
int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1)
36-
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);
41+
token_cnts_t* tokens_cnts = (token_cnts_t*)
42+
shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts)
43+
int32_t* cumsum =
44+
shared_mem + aligned_offset; // 1d tensor with shape (num_experts + 1)
3745

3846
for (int i = 0; i < num_experts; ++i) {
3947
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;

0 commit comments

Comments (0)