
Commit f203673

Authored by Pavani Majety
[ModelOpt] Introduce VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE env var to control blockscale tensor allocation (#18160)
Signed-off-by: Pavani Majety <[email protected]>
1 parent 7d92164 commit f203673

File tree

vllm/_custom_ops.py
vllm/envs.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py

3 files changed: +20 −18 lines

vllm/_custom_ops.py

Lines changed: 9 additions & 3 deletions
@@ -1085,7 +1085,6 @@ def scaled_fp4_experts_quant(
     blockscale_offsets: torch.Tensor,
     topk: int,
     expert_map: Optional[torch.Tensor] = None,
-    MAX_TOKENS_PER_EXPERT: int = 163840,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize input tensor to FP4 and return quantized tensor and scale, for
@@ -1107,9 +1106,16 @@ def scaled_fp4_experts_quant(
     input_tensor = input_tensor[
         expert_map] if expert_map is not None else input_tensor
     m_numtopk, k = input_tensor.shape
+    # Control the maximum number of tokens per expert supported by the
+    # NVFP4 MoE Expert Quantization. This is used to prevent the kernel
+    # from running out of memory. This value can also be increased to support
+    # larger models.
+    MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
     assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), (
-        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT * topk for"
-        f" scaled_fp4_experts_quant kernel, observed m_numtopk = {m_numtopk}")
+        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
+        f"{MAX_TOKENS_PER_EXPERT})"
+        f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
+        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value.")
     scales_k = k // 16
     padded_k = (scales_k + (4 - 1)) // 4
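
For readers skimming the diff, a minimal standalone sketch of the check that scaled_fp4_experts_quant now performs, with the limit read from the environment instead of the removed MAX_TOKENS_PER_EXPERT keyword argument. The topk, m_numtopk, and k values are illustrative, not taken from the commit:

    import os

    # The real code reads envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE; reading the raw
    # environment variable here keeps the sketch self-contained.
    max_tokens_per_expert = int(
        os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840"))

    topk = 8               # experts selected per token (illustrative)
    m_numtopk = 1_000_000  # rows of the expanded (token, expert) activation tensor (illustrative)
    k = 7_168              # hidden size (illustrative)

    # Same condition as the assert added in this hunk.
    assert m_numtopk <= max_tokens_per_expert * topk, (
        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT({max_tokens_per_expert})"
        f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value.")

    # Blockscale sizing arithmetic from the surrounding context lines:
    scales_k = k // 16                    # one block scale per 16 elements
    padded_k = (scales_k + (4 - 1)) // 4  # ceil(scales_k / 4), used to pad the blockscale layout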

vllm/envs.py

Lines changed: 8 additions & 0 deletions
@@ -117,6 +117,7 @@
     VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
     VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
     VLLM_ALL2ALL_BACKEND: str = "naive"
+    VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840


 def get_default_cache_root():
@@ -814,6 +815,13 @@ def get_vllm_port() -> Optional[int]:
     # - "pplx": use pplx kernels
     "VLLM_ALL2ALL_BACKEND":
     lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
+
+    # Control the maximum number of tokens per expert supported by the
+    # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
+    # the blockscale tensor of activations NVFP4 Quantization.
+    # This is used to prevent the kernel from running out of memory.
+    "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
+    lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
 }

 # --8<-- [end:env-vars-definition]
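
A usage sketch for the new variable: the override value 327680 is an arbitrary example, and this assumes vLLM is installed and the variable is set before its value is first read:

    import os

    # Set the override before vLLM reads its environment variables.
    os.environ["VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE"] = "327680"

    import vllm.envs as envs

    # The lambda above parses the string to an int, so this prints 327680.
    print(envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE)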

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 3 additions & 15 deletions
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 """ CUTLASS based Fused MoE kernels."""
-import os
 from typing import Optional

 import torch
@@ -271,8 +270,6 @@ def cutlass_moe_fp8(

 FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
-MAX_TOKENS_PER_EXPERT = int(
-    os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))


 def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
@@ -330,10 +327,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
     assert (topk_weights.shape[0] == m and topk_ids.shape[0]
             == m), ("topk must be provided for each row of a")
-    assert (m <= MAX_TOKENS_PER_EXPERT), (
-        f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
-        f" for cutlass_moe_fp4, observed m = {m}. Use"
-        f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value.")
+
     out_dtype = a.dtype
     num_topk = topk_ids.shape[1]
@@ -362,8 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
         expert_offsets,
         blockscale_offsets,
         num_topk,
-        expert_map=a_map,
-        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
+        expert_map=a_map)

     c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
                                 w1_blockscale, w1_alphas, problem_sizes1,
@@ -378,12 +371,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     torch.ops._C.silu_and_mul(intermediate, c1)

     int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
-        intermediate,
-        a2_gscale,
-        expert_offsets,
-        blockscale_offsets,
-        num_topk,
-        MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
+        intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk)

     c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
                                 w2_alphas, problem_sizes2, expert_offsets[:-1],
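
To make the net effect of this file's changes concrete, a small arithmetic sketch comparing the removed per-rank check against the one that remains in scaled_fp4_experts_quant (topk is illustrative):

    old_default = 65_536   # default of the removed VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT
    new_default = 163_840  # default of VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
    topk = 8               # experts selected per token (illustrative)

    # Removed check in cutlass_moe_fp4: m <= old_default (capped raw tokens at 65,536 by default).
    # Remaining check in scaled_fp4_experts_quant: m_numtopk <= new_default * topk,
    # which with topk=8 allows up to 1,310,720 expanded (token, expert) rows by default.
    print(new_default * topk)  # 1310720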
