
Commit 19d98e0

[Kernel] Optimize moe intermediate_cache usage (#13625)
Signed-off-by: mgoin <[email protected]>
1 parent 2b04c20 commit 19d98e0

File tree

1 file changed: +11 -6 lines changed

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 11 additions & 6 deletions
@@ -1240,15 +1240,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,
 
     config = get_config_func(M)
 
-    intermediate_cache1 = torch.empty((M, top_k_num, N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    # We can reuse the memory between these because by the time we need
+    # cache3, we're done with cache1
+    cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]),
+                          device=hidden_states.device,
+                          dtype=hidden_states.dtype)
+    intermediate_cache1 = cache13[:M * top_k_num * N].view(
+        (M, topk_ids.shape[1], N))
+    intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view(
+        (M, topk_ids.shape[1], w2.shape[1]))
+
+    # This needs separate memory since it's used concurrently with cache1
     intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
 
     if hidden_states.dtype == torch.bfloat16:
         compute_type = tl.bfloat16
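
The allocation pattern introduced by this commit can be reproduced in isolation. The sketch below is a minimal standalone example, not the vLLM code itself: M, top_k, N and K are made-up stand-ins for the real MoE dimensions taken from hidden_states, w1 and w2. It shows how one flat buffer backs both intermediate_cache1 and intermediate_cache3 through views, while intermediate_cache2 keeps its own allocation because it is live at the same time as cache1.

import torch

# Made-up stand-in dimensions for the real MoE shapes
# (num tokens, experts per token, w1 output size, w2 output size).
M, top_k, N, K = 16, 2, 256, 128
device, dtype = "cpu", torch.float16

# One flat buffer sized for the larger of cache1 (M*top_k*N elements)
# and cache3 (M*top_k*K elements); they are never live at the same time,
# so they can alias the same storage.
cache13 = torch.empty(M * top_k * max(N, K), device=device, dtype=dtype)
intermediate_cache1 = cache13[:M * top_k * N].view(M, top_k, N)
intermediate_cache3 = cache13[:M * top_k * K].view(M, top_k, K)

# cache2 is consumed while cache1 is still needed, so it keeps its own storage.
intermediate_cache2 = torch.empty((M * top_k, N // 2), device=device, dtype=dtype)

# Both views start at offset 0 of the same buffer: no second large allocation.
assert intermediate_cache1.data_ptr() == intermediate_cache3.data_ptr()
assert intermediate_cache1.data_ptr() == cache13.data_ptr()

In the actual fused_experts_impl path, the saving is the entire former intermediate_cache3 allocation (M x top_k_num x w2.shape[1] elements), since that tensor now aliases the shared cache13 buffer instead of getting its own torch.empty call.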
