
[Misc] Update w2 scale loading for GPTQMarlinMoE #12757


Merged (3 commits) on Feb 6, 2025
2 changes: 2 additions & 0 deletions tests/weight_loading/models-large.txt
@@ -1,5 +1,7 @@
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
compressed-tensors, nm-testing/test-w4a16-mixtral-actorder-group, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True
awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
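The act-order entries above exercise the loading path this PR fixes. Each line in models-large.txt is a comma-separated triple of quantization method, model repository, and revision; a minimal parsing sketch (illustrative only, not the actual weight-loading test harness):

```python
# Illustrative only: split one models-large.txt entry into its three fields.
# The real weight-loading test harness may consume this file differently.
def parse_model_entry(line: str) -> tuple[str, str, str]:
    quantization, model, revision = (field.strip() for field in line.split(","))
    return quantization, model, revision

entry = "gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True"
assert parse_model_entry(entry) == ("gptq_marlin",
                                    "TheBloke/Mixtral-8x7B-v0.1-GPTQ",
                                    "gptq-8bit-128g-actorder_True")
```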
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -302,8 +302,8 @@ def __init__(
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__ ==
"CompressedTensorsWNA16MoEMethod"):
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size

self.quant_method.create_weights(layer=self, **moe_quant_params)
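Passing intermediate_size_full matters because, with GPTQ act order (desc_act=True), the w2 group scales are laid out over the full pre-sharding intermediate dimension rather than the per-rank slice. A minimal sketch of the resulting scale row count, mirroring the scales_size2 computation in the gptq_marlin.py diff below (numbers are illustrative, not taken from a specific model config):

```python
# Sketch (not vLLM's actual helper): how many group-scale rows w2 needs on one
# tensor-parallel rank, mirroring the scales_size2 computation in the diff.
def w2_scale_rows(intermediate_size_per_partition: int,
                  intermediate_size_full: int,
                  group_size: int,
                  desc_act: bool) -> int:
    if group_size == -1:
        # Channel-wise quantization: a single scale group.
        return 1
    # With act order, scales are grouped over the full (pre-sharding) K dim,
    # so every rank must allocate and load the full set of rows.
    k = intermediate_size_full if desc_act else intermediate_size_per_partition
    return k // group_size

# Illustrative numbers: full intermediate size 14336, TP=2, group size 128.
assert w2_scale_rows(7168, 14336, 128, desc_act=False) == 56   # sharded scales
assert w2_scale_rows(7168, 14336, 128, desc_act=True) == 112   # full scales
```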
23 changes: 17 additions & 6 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -323,13 +323,18 @@ def create_weights(
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Currently assuming is_k_full is always True
# (input size per partition is the same as full input size)
# Supports only sym for now (no zp)
intermediate_size_full = extra_weight_attrs.pop(
"intermediate_size_full")

self.is_k_full = (not self.quant_config.desc_act) or (
intermediate_size_per_partition == intermediate_size_full)

if self.quant_config.group_size != -1:
scales_size13 = hidden_size // self.quant_config.group_size
scales_size2 = (intermediate_size_per_partition //
self.quant_config.group_size)
w2_scales_size = (intermediate_size_full
if self.quant_config.desc_act else
intermediate_size_per_partition)
scales_size2 = (w2_scales_size // self.quant_config.group_size)
strategy = FusedMoeWeightScaleSupported.GROUP.value
else:
scales_size13 = 1
@@ -385,6 +390,9 @@ def create_weights(
)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
# don't shard the w2 scales when running act order
set_weight_attrs(w2_scales,
{"load_full_w2": self.quant_config.desc_act})
# up_proj qzeros
w13_qzeros = torch.nn.Parameter(
torch.empty(num_experts,
@@ -406,6 +414,9 @@ def create_weights(
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
# don't shard the w2 qzeros when running act order
set_weight_attrs(w2_qzeros,
{"load_full_w2": self.quant_config.desc_act})
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
@@ -575,4 +586,4 @@ def apply(
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.quant_config.quant_type.size_bits,
).to(orig_dtype)
is_k_full=self.is_k_full).to(orig_dtype)
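Two pieces tie the change together: the load_full_w2 attribute marks w2_scales and w2_qzeros so that, as the added comments note, the MoE weight loader loads them unsharded when act order is enabled, and is_k_full records whether each rank's reduction dimension is complete so the fused Marlin MoE kernel can be informed via the new is_k_full argument. A standalone sketch of the is_k_full condition added in create_weights (values are illustrative, not a vLLM API):

```python
# Sketch of the is_k_full condition added above (illustrative, not a vLLM API).
def compute_is_k_full(desc_act: bool,
                      intermediate_size_per_partition: int,
                      intermediate_size_full: int) -> bool:
    # Without act order, each rank's K slice is self-contained.
    # With act order, K is only "full" when w2 is not sharded across ranks.
    return (not desc_act) or (
        intermediate_size_per_partition == intermediate_size_full)

assert compute_is_k_full(False, 7168, 14336) is True   # no act order, TP sharded
assert compute_is_k_full(True, 7168, 14336) is False   # act order + TP sharding
assert compute_is_k_full(True, 14336, 14336) is True   # act order, no sharding
```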