[Hardware/NVIDIA/Modelopt] Fix modelopt forward method for v1 torch.compile #18101

Merged
merged 4 commits on May 14, 2025

7 changes: 5 additions & 2 deletions vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """ CUTLASS based Fused MoE kernels."""
+import os
 from typing import Optional
 
 import torch
@@ -183,7 +184,8 @@ def cutlass_moe_fp8(
 
 FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
-MAX_TOKENS_PER_EXPERT = 65536
+MAX_TOKENS_PER_EXPERT = int(
+    os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))
 
 
 def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
@@ -243,7 +245,8 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
             == m), ("topk must be provided for each row of a")
     assert (m <= MAX_TOKENS_PER_EXPERT), (
         f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
-        f" for cutlass_moe_fp4, observed m = {m}")
+        f" for cutlass_moe_fp4, observed m = {m}. Use"
+        f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value.")
     out_dtype = a.dtype
     num_topk = topk_ids.shape[1]
 
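Since the new cap is read once at module import time, any override has to be in the environment before vLLM is imported (or before the server process starts). A minimal sketch of the override, assuming an installed vLLM build; the value 131072 is only illustrative:

import os

# Set the override before any vLLM module is imported; the diff above reads
# the variable a single time at import, defaulting to 65536.
os.environ["VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT"] = "131072"

from vllm.model_executor.layers.fused_moe import cutlass_moe

assert cutlass_moe.MAX_TOKENS_PER_EXPERT == 131072
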
17 changes: 9 additions & 8 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -401,6 +401,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
 
         layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                 requires_grad=False)
+        layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
         if self.use_marlin:
             prepare_fp4_layer_for_marlin(layer)
@@ -426,11 +427,7 @@ def apply(
                 bias=bias)
 
         output_dtype = x.dtype
-
-        # for input only the contracting dimension has a constraint.
-        x_m, _ = x.shape
-        w_n, _ = layer.weight.shape
-        output_shape = [x_m, w_n]
+        output_shape = [x.shape[0], layer.weight.shape[0]]
 
         # quantize BF16 or FP16 to (FP4 and interleaved block scale)
         s_quant = 1 / layer.input_scale
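As an aside, the removed shape computation and the new one-liner are equivalent; a standalone check, not part of the PR, with arbitrarily chosen tensor sizes:

import torch

x = torch.randn(4, 16)        # activations: [num_tokens, hidden_size]
weight = torch.randn(32, 16)  # weight: [output_size, input_size]

# Removed form: unpack both shapes, keep only the first dimension of each.
x_m, _ = x.shape
w_n, _ = weight.shape

# New form from the diff: index the needed dimensions directly.
assert [x_m, w_n] == [x.shape[0], weight.shape[0]]
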
@@ -586,11 +583,11 @@ def swizzle_blockscale(self, scale: torch.tensor):
                 if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # GEMM 1
 
+        # GEMM 1
         assert torch.allclose(
             layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), (
-                "Expected w1_weight_scale_2 to equal w3_weight_scale_2")
+                "w1_weight_scale_2 must match w3_weight_scale_2")
 
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
@@ -616,6 +613,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_input_scale_quant = Parameter(
             (1 / w13_input_scale).to(torch.float32), requires_grad=False)
 
+        layer.w13_weight = Parameter(layer.w13_weight.data,
+                                     requires_grad=False)
+
         # GEMM 2
         layer.g2_alphas = Parameter(
             (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
@@ -633,6 +633,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
                                                  requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
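The layer.weight, layer.w13_weight, and layer.w2_weight lines added in this file all follow the same pattern: after post-load processing, the tensor is re-registered on the module as a frozen Parameter, presumably so it remains a proper module parameter when the model runs under v1 torch.compile (per the PR title). A minimal, self-contained sketch of that pattern, using a plain nn.Linear as a stand-in for the quantized layer:

import torch
from torch.nn import Parameter

layer = torch.nn.Linear(16, 32, bias=False)  # stand-in for the quantized layer

# Mirror the diff: re-wrap the (already processed) storage as a frozen
# Parameter assigned back onto the module.
layer.weight = Parameter(layer.weight.data, requires_grad=False)

assert isinstance(layer.weight, Parameter)
assert layer.weight.requires_grad is False
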
@@ -694,7 +695,7 @@ def apply(
         assert not apply_router_weight_on_input, (
             "Router weight on input is not "
             "supported for ModelOptNvFp4FusedMoE.")
-        assert expert_map is None, ("Expert Parallelism /expert_map "
+        assert expert_map is None, ("Expert Parallelism / expert_map "
                                     "is currently not supported for "
                                     "ModelOptNvFp4FusedMoE.")
 