Commit 04711d5

mgoin authored and Mu Huai committed

[Kernel] Use moe_wna16 kernel for compressed tensors wna16 moe models (vllm-project#16038)

Signed-off-by: mgoin <[email protected]>
Signed-off-by: Mu Huai <[email protected]>

1 parent 116fb0b commit 04711d5

File tree

5 files changed: +254 -15 lines changed

.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16 -b auto -l 1319 -f 5 -t 1
+model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.31
+  - name: "exact_match,flexible-extract"
+    value: 0.47
+limit: 1319
+num_fewshot: 5
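For reference, the baseline script invoked in the comment above drives lm-eval-harness with these same settings. A rough Python equivalent is sketched below, assuming lm-eval 0.4.x with the vLLM backend installed; the helper name and argument values simply mirror the config and are not part of this commit:

from lm_eval import simple_evaluate

# Sketch: evaluate the quantized MoE checkpoint on GSM8K with the same
# settings as the config above (5-shot, 1319-sample limit, auto batch size).
results = simple_evaluate(
    model="vllm",
    model_args="pretrained=nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16,"
               "tensor_parallel_size=1",
    tasks=["gsm8k"],
    num_fewshot=5,
    limit=1319,
    batch_size="auto",
)
print(results["results"]["gsm8k"])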

.buildkite/lm-eval-harness/configs/models-small.txt

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
 Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
-Minitron-4B-Base-FP8.yaml
+Qwen1.5-MoE-W4A16-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-FP8W8.yaml
 Meta-Llama-3-8B-QQQ.yaml

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 7 additions & 4 deletions

@@ -512,7 +512,9 @@ def __init__(
         }
         # need full intermediate size pre-sharding for WNA16 act order
         if (self.quant_method.__class__.__name__
-                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
+                in ("GPTQMarlinMoEMethod",
+                    "CompressedTensorsWNA16MarlinMoEMethod",
+                    "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size

         self.quant_method.create_weights(layer=self, **moe_quant_params)
@@ -648,9 +650,10 @@ def weight_loader(self, param: torch.nn.Parameter,
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
-        loaded_weight = loaded_weight.t().contiguous() if (
-            self.quant_method.__class__.__name__
-            == "CompressedTensorsWNA16MoEMethod") else loaded_weight
+        if self.quant_method.__class__.__name__ in (
+                "CompressedTensorsWNA16MarlinMoEMethod",
+                "CompressedTensorsWNA16MoEMethod"):
+            loaded_weight = loaded_weight.t().contiguous()

         if shard_id not in ("w1", "w2", "w3"):
             raise ValueError(f"shard_id must be ['w1','w2','w3'] but "

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 1 addition & 2 deletions

@@ -96,8 +96,7 @@ def get_quant_method(
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         if isinstance(layer, FusedMoE):
-            return CompressedTensorsMoEMethod.get_moe_method(
-                self, layer.activation, layer.expert_map)
+            return CompressedTensorsMoEMethod.get_moe_method(self, layer)
         return None

     @classmethod

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 234 additions & 8 deletions

@@ -6,7 +6,8 @@
 
 import torch
 from compressed_tensors import CompressionFormat
-from compressed_tensors.quantization import QuantizationStrategy
+from compressed_tensors.quantization import (ActivationOrdering,
+                                             QuantizationStrategy)
 
 import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
@@ -30,9 +31,11 @@ class GPTQMarlinState(Enum):
 
 
 __all__ = [
-    "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
+    "CompressedTensorsMoEMethod",
+    "CompressedTensorsW8A8Fp8MoEMethod",
     "CompressedTensorsW8A8Fp8MoECutlassMethod",
-    "CompressedTensorsWNA16MoEMethod"
+    "CompressedTensorsWNA16MarlinMoEMethod",
+    "CompressedTensorsWNA16MoEMethod",
 ]
 
 
@@ -41,8 +44,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     @staticmethod
     def get_moe_method(
         quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
-        activation: str,
-        expert_map: Optional[torch.Tensor],
+        layer: torch.nn.Module,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -51,9 +53,21 @@ def get_moe_method(
             "input_activations")
 
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
-            return CompressedTensorsWNA16MoEMethod(quant_config)
+            # Prefer to use the non-marlin kernel when:
+            # 1. Many experts (MarlinMoE gives poor performance when >= 16)
+            # 2. Non-FP16 dtype (MarlinMoE only supports FP16)
+            # 3. Actorder is not group/dynamic (g_idx is unsupported)
+            # 4. Scales are grouped (channelwise is unsupported)
+            if ((layer.local_num_experts >= 16
+                 or layer.params_dtype != torch.float16) and
+                    weight_quant.actorder not in (ActivationOrdering.GROUP,
+                                                  ActivationOrdering.DYNAMIC)
+                    and weight_quant.strategy in QuantizationStrategy.GROUP):
+                return CompressedTensorsWNA16MoEMethod(quant_config)
+            else:
+                return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
         elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
-              and activation == "silu" and expert_map is None):
+              and layer.activation == "silu" and layer.expert_map is None):
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
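Read standalone, the preference logic in this hunk reduces to something like the sketch below. The helper name prefers_moe_wna16 is hypothetical and only mirrors the condition inside get_moe_method; the real decision lives in the diff above:

import torch
from compressed_tensors.quantization import (ActivationOrdering,
                                             QuantizationStrategy)

# Hypothetical helper mirroring the heuristic: pick the fused moe_wna16
# kernel unless MarlinMoE is both usable (FP16, no g_idx, grouped scales
# not required to be channelwise) and likely faster (few experts).
def prefers_moe_wna16(local_num_experts: int, params_dtype: torch.dtype,
                      actorder, strategy) -> bool:
    marlin_would_be_slow = local_num_experts >= 16
    marlin_unsupported_dtype = params_dtype != torch.float16
    needs_g_idx = actorder in (ActivationOrdering.GROUP,
                               ActivationOrdering.DYNAMIC)
    grouped_scales = strategy in QuantizationStrategy.GROUP
    return ((marlin_would_be_slow or marlin_unsupported_dtype)
            and not needs_g_idx and grouped_scales)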
@@ -482,7 +496,7 @@ def apply(
         )
 
 
-class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
 
     def __init__(
         self,
@@ -823,3 +837,215 @@ def apply(
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
             is_k_full=self.is_k_full)
+
+
+class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+            self,
+            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        # TODO: @dsikka: refactor this to use schemes as other kernels
+        # are supported + check if the layer is being ignored.
+        config = self.quant_config.target_scheme_map["Linear"].get("weights")
+        self.num_bits = config.num_bits
+        self.packed_factor = 32 // config.num_bits
+        self.strategy = config.strategy
+        # channelwise is not supported by this kernel
+        assert config.strategy == "group"
+        self.group_size = config.group_size
+        # grouped actorder isn't supported by this kernel
+        assert config.actorder != "group"
+        assert config.symmetric, (
+            "Only symmetric quantization is supported for MoE")
+
+        if not (self.quant_config.quant_format
+                == CompressionFormat.pack_quantized.value
+                and self.num_bits in WNA16_SUPPORTED_BITS):
+            raise ValueError("For Fused MoE layers, only ",
+                             f"{CompressionFormat.pack_quantized.value} ",
+                             "is supported for the following bits: ",
+                             f"{WNA16_SUPPORTED_BITS}")
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        # Will transpose the loaded weight along the
+        # intermediate and hidden dim sizes. Will
+        # shard for TP along the transposed dims
+        extra_weight_attrs.update({
+            "is_transposed": True,
+            "quant_method": self.strategy
+        })
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size // self.packed_factor,
+            2 * intermediate_size_per_partition,
+            dtype=torch.int32),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight_packed", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            intermediate_size_per_partition // self.packed_factor,
+            hidden_size,
+            dtype=torch.int32),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight_packed", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w2_scales_size = intermediate_size_per_partition
+
+        if self.strategy == "channel":
+            num_groups_w2 = num_groups_w13 = 1
+            self.group_size = -1
+        else:
+            num_groups_w2 = w2_scales_size // self.group_size
+            num_groups_w13 = hidden_size // self.group_size
+
+        w13_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            num_groups_w13,
+            2 * intermediate_size_per_partition,
+            dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_scale)
+        set_weight_attrs(w13_scale, extra_weight_attrs)
+
+        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                 num_groups_w2,
+                                                 hidden_size,
+                                                 dtype=params_dtype),
+                                      requires_grad=False)
+        layer.register_parameter("w2_weight_scale", w2_scale)
+        set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": False})
+
+        w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                             requires_grad=False)
+        layer.register_parameter("w2_weight_shape", w2_weight_shape)
+        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
+        w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                              requires_grad=False)
+
+        layer.register_parameter("w13_weight_shape", w13_weight_shape)
+        set_weight_attrs(w13_weight_shape, extra_weight_attrs)
+
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices",
+                                 w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices",
+                                 w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+        layer.a13_scale = None
+        layer.a2_scale = None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Reconfigure packed weights and scales to match moe_wna16 format
+        layer.w13_weight_packed = torch.nn.Parameter(
+            layer.w13_weight_packed.transpose(1, 2).contiguous().view(
+                torch.uint8),
+            requires_grad=False)
+        layer.w2_weight_packed = torch.nn.Parameter(
+            layer.w2_weight_packed.transpose(1,
+                                             2).contiguous().view(torch.uint8),
+            requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(
+            layer.w13_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(
+            layer.w2_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe import fused_experts
+        assert activation == "silu", "Only SiLU activation is supported."
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        return fused_experts(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_int4_w4a16=self.num_bits == 4,
+            use_int8_w8a16=self.num_bits == 8,
+            global_num_experts=global_num_experts,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size])
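For a quick end-to-end check of the new path, something like the following should route a compressed-tensors WNA16 MoE checkpoint through the moe_wna16 kernel. This is a sketch, not part of the commit: the Qwen1.5-MoE checkpoint from the lm-eval config above has far more than 16 local experts, so together with a non-FP16 dtype it satisfies the preference heuristic; prompt and sampling values are illustrative.

from vllm import LLM, SamplingParams

# Sketch: load the w4a16 compressed-tensors MoE checkpoint used by the
# lm-eval config. Many experts + bfloat16 activations means get_moe_method
# should select CompressedTensorsWNA16MoEMethod rather than the Marlin path.
llm = LLM(model="nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16",
          dtype="bfloat16")

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)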
