diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index 1b34e952208..7f96a401271 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """ CUTLASS based Fused MoE kernels."""
+import os
 from typing import Optional
 
 import torch
@@ -183,7 +184,8 @@ def cutlass_moe_fp8(
 
 FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
-MAX_TOKENS_PER_EXPERT = 65536
+MAX_TOKENS_PER_EXPERT = int(
+    os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))
 
 
 def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
@@ -243,7 +245,8 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
             == m), ("topk must be provided for each row of a")
     assert (m <= MAX_TOKENS_PER_EXPERT), (
         f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
-        f" for cutlass_moe_fp4, observed m = {m}")
+        f" for cutlass_moe_fp4, observed m = {m}. Use"
+        f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value.")
 
     out_dtype = a.dtype
     num_topk = topk_ids.shape[1]
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index bd9daa7c608..50493ccf2b7 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -401,6 +401,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
 
         layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                 requires_grad=False)
+        layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
         if self.use_marlin:
             prepare_fp4_layer_for_marlin(layer)
@@ -426,11 +427,7 @@ def apply(
                                        bias=bias)
 
         output_dtype = x.dtype
-
-        # for input only the contracting dimension has a constraint.
-        x_m, _ = x.shape
-        w_n, _ = layer.weight.shape
-        output_shape = [x_m, w_n]
+        output_shape = [x.shape[0], layer.weight.shape[0]]
 
         # quantize BF16 or FP16 to (FP4 and interleaved block scale)
         s_quant = 1 / layer.input_scale
@@ -586,11 +583,11 @@ def swizzle_blockscale(self, scale: torch.tensor):
                 if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # GEMM 1 
+        # GEMM 1
         assert torch.allclose(
             layer.w13_weight_scale_2[:, 0],
             layer.w13_weight_scale_2[:, 1]), (
-                "Expected w1_weight_scale_2 to equal w3_weight_scale_2")
+                "w1_weight_scale_2 must match w3_weight_scale_2")
 
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
@@ -616,6 +613,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_input_scale_quant = Parameter(
             (1 / w13_input_scale).to(torch.float32), requires_grad=False)
 
+        layer.w13_weight = Parameter(layer.w13_weight.data,
+                                     requires_grad=False)
+
         # GEMM 2
         layer.g2_alphas = Parameter(
             (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
@@ -633,6 +633,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
                                                  requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
@@ -694,7 +695,7 @@ def apply(
         assert not apply_router_weight_on_input, (
             "Router weight on input is not "
             "supported for ModelOptNvFp4FusedMoE.")
-        assert expert_map is None, ("Expert Parallelism /expert_map "
+        assert expert_map is None, ("Expert Parallelism / expert_map "
                                     "is currently not supported for "
                                     "ModelOptNvFp4FusedMoE.")
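Usage note (not part of the patch): the new MAX_TOKENS_PER_EXPERT override is read once via os.environ.get at module import time, so the environment variable must be set before vLLM is imported. A minimal sketch, assuming the override is exercised from Python; the value 131072 is an arbitrary example:

    import os

    # Must be set before importing vLLM, because cutlass_moe.py reads the
    # variable at import time (default '65536' when unset).
    os.environ["VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT"] = "131072"

    from vllm.model_executor.layers.fused_moe import cutlass_moe

    # The module-level constant now reflects the override, so the
    # cutlass_moe_fp4 assertion allows up to 131072 tokens per expert.
    assert cutlass_moe.MAX_TOKENS_PER_EXPERT == 131072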