-
-
Notifications
You must be signed in to change notification settings - Fork 7.8k
[Model] Introduce CUDA Graph support for DeepSeek v3 #12204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -221,29 +221,31 @@ __global__ void moe_sum_kernel( | |
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, | ||
int64_t block_size, torch::Tensor sorted_token_ids, | ||
torch::Tensor experts_ids, | ||
torch::Tensor num_tokens_post_pad) { | ||
torch::Tensor num_tokens_post_pad, | ||
bool use_shared_memory) { | ||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||
|
||
// If we have very large number of experts, we can no longer use shared | ||
// memory. | ||
// TODO(simon): the right solution should be calculating the exact right | ||
// amount of shared memory and use that. The num_experts >= 256 is just a | ||
// temporary solution to unblock Deepseek V3. | ||
if (num_experts >= 256) { | ||
if (!use_shared_memory) { | ||
VLLM_DISPATCH_INTEGRAL_TYPES( | ||
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { | ||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` | ||
// tensors | ||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); | ||
|
||
const int32_t mem_tokens_cnts = | ||
((num_experts + 1) * num_experts) * sizeof(int32_t); | ||
const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t); | ||
// allocate global memory | ||
int32_t* tokens_cnts; | ||
int32_t* cumsum; | ||
cudaMalloc(&tokens_cnts, mem_tokens_cnts); | ||
cudaMalloc(&cumsum, mem_cumsum); | ||
torch::Tensor token_cnts = | ||
torch::empty({(num_experts + 1) * num_experts}, | ||
torch::TensorOptions() | ||
.dtype(torch::kInt) | ||
.device(topk_ids.device())); | ||
torch::Tensor cumsum = | ||
torch::empty({num_experts + 1}, torch::TensorOptions() | ||
.dtype(torch::kInt) | ||
.device(topk_ids.device())); | ||
Comment on lines
-239
to
+248
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense to me, since during cuda graph capture, some actions, such as cudaMalloc, may be unsafe There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice doc pointer, thanks |
||
|
||
auto kernel = | ||
vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>; | ||
|
@@ -252,9 +254,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, | |
sorted_token_ids.data_ptr<int32_t>(), | ||
experts_ids.data_ptr<int32_t>(), | ||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, | ||
topk_ids.numel(), tokens_cnts, cumsum); | ||
cudaFree(tokens_cnts); | ||
cudaFree(cumsum); | ||
topk_ids.numel(), token_cnts.data_ptr<int32_t>(), | ||
cumsum.data_ptr<int32_t>()); | ||
}); | ||
} else { | ||
VLLM_DISPATCH_INTEGRAL_TYPES( | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why was this change needed for the fix BTW?