
[Model] Introduce CUDA Graph support for DeepSeek v3 #12204


Closed · wants to merge 1 commit
27 changes: 14 additions & 13 deletions csrc/moe/moe_align_sum_kernels.cu
@@ -221,29 +221,31 @@ __global__ void moe_sum_kernel(
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
torch::Tensor num_tokens_post_pad,
bool use_shared_memory) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// If we have very large number of experts, we can no longer use shared
// memory.
// TODO(simon): the right solution should be calculating the exact right
// amount of shared memory and use that. The num_experts >= 256 is just a
// temporary solution to unblock Deepseek V3.
if (num_experts >= 256) {
if (!use_shared_memory) {
Comment on lines -232 to +233 (Collaborator):
Why was this change needed for the fix BTW?

VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);

const int32_t mem_tokens_cnts =
((num_experts + 1) * num_experts) * sizeof(int32_t);
const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
// allocate global memory
int32_t* tokens_cnts;
int32_t* cumsum;
cudaMalloc(&tokens_cnts, mem_tokens_cnts);
cudaMalloc(&cumsum, mem_cumsum);
torch::Tensor token_cnts =
torch::empty({(num_experts + 1) * num_experts},
torch::TensorOptions()
.dtype(torch::kInt)
.device(topk_ids.device()));
torch::Tensor cumsum =
torch::empty({num_experts + 1}, torch::TensorOptions()
.dtype(torch::kInt)
.device(topk_ids.device()));
Comment on lines -239 to +248 (Member):
Nice doc pointer, thanks


auto kernel =
vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
@@ -252,9 +254,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), tokens_cnts, cumsum);
cudaFree(tokens_cnts);
cudaFree(cumsum);
topk_ids.numel(), token_cnts.data_ptr<int32_t>(),
cumsum.data_ptr<int32_t>());
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
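Context for the allocation change above (and for the review question about why it was needed): cudaMalloc and cudaFree generally cannot be issued while a CUDA graph is being captured, whereas torch::empty goes through PyTorch's caching allocator, which is capture-aware. A minimal sketch of the pattern under that assumption; the helper name launch_with_workspace and the kernel-launch placeholder are illustrative, not part of this PR:

```cpp
#include <torch/all.h>

// Illustrative sketch of the allocation pattern adopted above: scratch buffers
// come from torch::empty (PyTorch caching allocator, CUDA-graph friendly)
// instead of raw cudaMalloc/cudaFree.
void launch_with_workspace(const torch::Tensor& topk_ids, int64_t num_experts) {
  auto options =
      torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());

  // Same shapes as the token_cnts / cumsum buffers used by the kernel above.
  torch::Tensor token_cnts =
      torch::empty({(num_experts + 1) * num_experts}, options);
  torch::Tensor cumsum = torch::empty({num_experts + 1}, options);

  // The raw pointers stay valid for the lifetime of the tensors, so no
  // explicit cudaFree is needed and no allocation API call appears on the
  // capture stream.
  int32_t* token_cnts_ptr = token_cnts.data_ptr<int32_t>();
  int32_t* cumsum_ptr = cumsum.data_ptr<int32_t>();

  // ... launch the alignment kernel with token_cnts_ptr / cumsum_ptr ...
  (void)token_cnts_ptr;
  (void)cumsum_ptr;
}
```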
3 changes: 2 additions & 1 deletion csrc/moe/moe_ops.h
@@ -11,4 +11,5 @@ void moe_sum(torch::Tensor& input, torch::Tensor& output);
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
torch::Tensor num_tokens_post_pad,
bool use_shared_memory);
3 changes: 2 additions & 1 deletion csrc/moe/torch_bindings.cpp
@@ -19,7 +19,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()");
" Tensor! num_tokens_post_pad,"
" bool use_shared_memory) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

#ifndef USE_ROCM
6 changes: 4 additions & 2 deletions vllm/_custom_ops.py
@@ -917,10 +917,12 @@ def moe_sum(input: torch.Tensor, output: torch.Tensor):
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor) -> None:
num_tokens_post_pad: torch.Tensor,
use_shared_memory: bool = True) -> None:
torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size,
sorted_token_ids, experts_ids,
num_tokens_post_pad)
num_tokens_post_pad,
use_shared_memory)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -607,7 +607,7 @@ def _verify_cuda_graph(self) -> None:
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)

MODEL_NOT_SUPPORT_CUDA_GRAPH = ['deepseek_v3', 'mllama']
MODEL_NOT_SUPPORT_CUDA_GRAPH = ['mllama']
if (self.hf_config.model_type in MODEL_NOT_SUPPORT_CUDA_GRAPH
and not self.enforce_eager):
logger.warning(
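With 'deepseek_v3' removed from the list, CUDA graph capture is no longer forced off for that model type, and eager mode remains available through the existing flag. A hedged usage sketch; the model path and settings are illustrative only:

```python
from vllm import LLM

# With this PR applied, deepseek_v3 no longer triggers the warning/fallback
# above, so decoding can use captured CUDA graphs (enforce_eager defaults to
# False). Model path and parallelism are example values.
llm = LLM(model="deepseek-ai/DeepSeek-V3",
          trust_remote_code=True,
          tensor_parallel_size=8)
```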
5 changes: 4 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -256,8 +256,11 @@ def moe_align_block_size(
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
# Note: DeepSeek V3 has 256 experts.
use_shared_memory = num_experts < 256
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
expert_ids, num_tokens_post_pad,
use_shared_memory)
return sorted_ids, expert_ids, num_tokens_post_pad
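For reference, a minimal sketch of how the new flag flows through the Python wrapper, mirroring the call above; it assumes a CUDA device with the bindings from this PR, and the tensor sizes and topk_ids contents are illustrative only:

```python
import torch
from vllm import _custom_ops as ops

# Illustrative routing: 4 tokens, top-2 experts each, DeepSeek V3-sized expert count.
num_experts, block_size, topk = 256, 16, 2
topk_ids = torch.randint(0, num_experts, (4, topk),
                         dtype=torch.int32, device="cuda")

max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_m_blocks = (max_num_tokens_padded + block_size - 1) // block_size
sorted_ids = torch.full((max_num_tokens_padded,), topk_ids.numel(),
                        dtype=torch.int32, device="cuda")
expert_ids = torch.empty((max_num_m_blocks,),
                         dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

# With 256 experts the shared-memory kernel would overflow, so the flag is
# False and the global-memory path (now torch-allocated) is taken.
use_shared_memory = num_experts < 256
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                         expert_ids, num_tokens_post_pad, use_shared_memory)
```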

