
Commit 7ca9934

[Misc] Update w2 scale loading for GPTQMarlinMoE (#12757)
1 parent 0408efc commit 7ca9934

3 files changed: +21 -8 lines changed

tests/weight_loading/models-large.txt (+2)

@@ -1,5 +1,7 @@
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
+compressed-tensors, nm-testing/test-w4a16-mixtral-actorder-group, main
 gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
+gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True
 awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
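
Each entry in this file is a comma-separated triple of quantization method, Hugging Face model ID, and revision; the two added lines exercise act-order (desc_act) GPTQ checkpoints. As a minimal sketch of that format (the parser below is hypothetical, not the actual test harness):

# Hypothetical reader for models-large.txt-style entries:
# "quantization, model, revision" per line.
from typing import NamedTuple


class WeightLoadingCase(NamedTuple):
    quantization: str
    model: str
    revision: str


def parse_cases(text: str) -> list[WeightLoadingCase]:
    cases = []
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue
        quantization, model, revision = (p.strip() for p in line.split(",", 2))
        cases.append(WeightLoadingCase(quantization, model, revision))
    return cases


cases = parse_cases(
    "gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True")
assert cases[0].revision == "gptq-8bit-128g-actorder_True"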

vllm/model_executor/layers/fused_moe/layer.py (+2 -2)

@@ -302,8 +302,8 @@ def __init__(
             "weight_loader": self.weight_loader,
         }
         # need full intermediate size pre-sharding for WNA16 act order
-        if (self.quant_method.__class__.__name__ ==
-                "CompressedTensorsWNA16MoEMethod"):
+        if (self.quant_method.__class__.__name__
+                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size
 
         self.quant_method.create_weights(layer=self, **moe_quant_params)
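
For context (an illustrative sketch, not vLLM's actual loader code): under tensor parallelism the intermediate dimension is split across ranks, so the layer normally only carries intermediate_size_per_partition. With act-order (desc_act) checkpoints, GPTQMarlinMoEMethod now also needs the pre-sharding size to size the w2 scales, which is why it joins CompressedTensorsWNA16MoEMethod in the check above. The sizes and the quant_method_name stand-in below are examples only:

# Illustrative sketch: per-partition vs. full intermediate size under TP.
intermediate_size = 14336          # full, pre-sharding intermediate size (example value)
tp_size = 4                        # hypothetical tensor-parallel degree
intermediate_size_per_partition = intermediate_size // tp_size   # 3584

moe_quant_params = {
    "intermediate_size_per_partition": intermediate_size_per_partition,
}

quant_method_name = "GPTQMarlinMoEMethod"  # stand-in for self.quant_method.__class__.__name__
if quant_method_name in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod"):
    # act-order w2 scales are laid out over the full K dimension, so the
    # method needs the unsharded size when it creates its weights
    moe_quant_params["intermediate_size_full"] = intermediate_size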

vllm/model_executor/layers/quantization/gptq_marlin.py (+17 -6)

@@ -323,13 +323,18 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        # Currently assuming is_k_full is always True
-        # (input size per partition is the same as full input size)
-        # Supports only sym for now (no zp)
+        intermediate_size_full = extra_weight_attrs.pop(
+            "intermediate_size_full")
+
+        self.is_k_full = (not self.quant_config.desc_act) or (
+            intermediate_size_per_partition == intermediate_size_full)
+
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            scales_size2 = (intermediate_size_per_partition //
-                            self.quant_config.group_size)
+            w2_scales_size = (intermediate_size_full
+                              if self.quant_config.desc_act else
+                              intermediate_size_per_partition)
+            scales_size2 = (w2_scales_size // self.quant_config.group_size)
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
             scales_size13 = 1
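
To make the new sizing arithmetic concrete, here is a small standalone sketch of the group-quantized path above (example sizes are illustrative and the channelwise branch is abbreviated from the surrounding context):

# Standalone sketch of the w2 scale sizing above (example sizes, not from the commit).
def moe_scale_sizes(hidden_size: int,
                    intermediate_size_per_partition: int,
                    intermediate_size_full: int,
                    group_size: int,
                    desc_act: bool) -> tuple[int, int]:
    if group_size == -1:
        # channelwise case (abbreviated): one scale per output channel
        return 1, 1
    scales_size13 = hidden_size // group_size
    # with act order the w2 scales span the full intermediate dimension,
    # otherwise only this rank's shard
    w2_scales_size = (intermediate_size_full
                      if desc_act else intermediate_size_per_partition)
    scales_size2 = w2_scales_size // group_size
    return scales_size13, scales_size2


# Mixtral-like sizes, group_size=128, TP=4:
print(moe_scale_sizes(4096, 3584, 14336, 128, desc_act=False))  # (32, 28)
print(moe_scale_sizes(4096, 3584, 14336, 128, desc_act=True))   # (32, 112)
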
@@ -385,6 +390,9 @@ def create_weights(
         )
         layer.register_parameter("w2_scales", w2_scales)
         set_weight_attrs(w2_scales, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_scales,
+                         {"load_full_w2": self.quant_config.desc_act})
         # up_proj scales
         w13_qzeros = torch.nn.Parameter(
             torch.empty(num_experts,

@@ -406,6 +414,9 @@ def create_weights(
         )
         layer.register_parameter("w2_qzeros", w2_qzeros)
         set_weight_attrs(w2_qzeros, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_qzeros,
+                         {"load_full_w2": self.quant_config.desc_act})
         w13_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,
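
The load_full_w2 attribute signals the MoE weight loader not to shard these tensors along the intermediate dimension when act order is enabled. The loader itself is outside this diff, so the following is only a rough sketch of that decision; the load_w2 helper and its shard_dim argument are hypothetical:

# Rough sketch (assumption, not the actual vLLM loader) of what load_full_w2 implies.
import torch


def load_w2(param: torch.Tensor,
            loaded_weight: torch.Tensor,
            shard_dim: int,
            shard_size: int,
            tp_rank: int,
            load_full_w2: bool) -> None:
    if load_full_w2:
        # act order: the parameter was created at full size, copy it whole
        param.copy_(loaded_weight)
    else:
        # default path: narrow to this rank's slice along the sharded dimension
        start = tp_rank * shard_size
        param.copy_(loaded_weight.narrow(shard_dim, start, shard_size))
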
@@ -575,4 +586,4 @@ def apply(
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.quant_config.quant_type.size_bits,
-        ).to(orig_dtype)
+            is_k_full=self.is_k_full).to(orig_dtype)
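
Per the comment this change removes, is_k_full indicates whether the input size per partition equals the full input size. With the definition added above it is false only when act order is enabled and the intermediate dimension is actually sharded; a quick truth table with hypothetical sizes:

# is_k_full for a few hypothetical (desc_act, TP degree) combinations.
def is_k_full(desc_act: bool, intermediate_size: int, tp_size: int) -> bool:
    intermediate_size_per_partition = intermediate_size // tp_size
    return (not desc_act) or (
        intermediate_size_per_partition == intermediate_size)


assert is_k_full(desc_act=False, intermediate_size=14336, tp_size=4)     # no act order
assert is_k_full(desc_act=True, intermediate_size=14336, tp_size=1)      # act order, no sharding
assert not is_k_full(desc_act=True, intermediate_size=14336, tp_size=4)  # act order + TP sharding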
