
Commit ccc1cf5

Merge branch 'compressed-tensors-moe-wna16' of https://github.com/neuralmagic/vllm

2 parents: 566f10a + 73f7af5

3 files changed (+240 / -13 lines)


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 8 additions & 4 deletions
@@ -473,6 +473,7 @@ def __init__(
 
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
+        self.params_dtype = params_dtype
         self.reduce_results = reduce_results
         self.renormalize = renormalize
         self.use_grouped_topk = use_grouped_topk
@@ -512,7 +513,9 @@ def __init__(
         }
         # need full intermediate size pre-sharding for WNA16 act order
         if (self.quant_method.__class__.__name__
-                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
+                in ("GPTQMarlinMoEMethod",
+                    "CompressedTensorsWNA16MarlinMoEMethod",
+                    "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size
 
         self.quant_method.create_weights(layer=self, **moe_quant_params)
@@ -648,9 +651,10 @@ def weight_loader(self, param: torch.nn.Parameter,
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
-        loaded_weight = loaded_weight.t().contiguous() if (
-            self.quant_method.__class__.__name__
-            == "CompressedTensorsWNA16MoEMethod") else loaded_weight
+        if self.quant_method.__class__.__name__ in (
+                "CompressedTensorsWNA16MarlinMoEMethod",
+                "CompressedTensorsWNA16MoEMethod"):
+            loaded_weight = loaded_weight.t().contiguous()
 
         if shard_id not in ("w1", "w2", "w3"):
             raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
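To see the weight_loader change in isolation: compressed-tensors checkpoints store packed MoE weights flipped, and the transpose now applies to both WNA16 MoE methods rather than only the old class name. A minimal sketch, assuming nothing beyond plain PyTorch (the helper name maybe_transpose_packed_weight is invented for illustration, not vLLM code):

import torch

def maybe_transpose_packed_weight(loaded_weight: torch.Tensor,
                                  quant_method_name: str) -> torch.Tensor:
    # Mirrors the guard above: packed compressed-tensors weights are stored
    # flipped, so transpose for either WNA16 MoE method before copying.
    if quant_method_name in ("CompressedTensorsWNA16MarlinMoEMethod",
                             "CompressedTensorsWNA16MoEMethod"):
        return loaded_weight.t().contiguous()
    return loaded_weight

# Example: a packed int32 tensor of shape (K // 8, N) becomes (N, K // 8).
w = torch.zeros(512, 1024, dtype=torch.int32)
print(maybe_transpose_packed_weight(w, "CompressedTensorsWNA16MoEMethod").shape)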

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 1 addition & 2 deletions
@@ -96,8 +96,7 @@ def get_quant_method(
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         if isinstance(layer, FusedMoE):
-            return CompressedTensorsMoEMethod.get_moe_method(
-                self, layer.activation, layer.expert_map)
+            return CompressedTensorsMoEMethod.get_moe_method(self, layer)
         return None
 
     @classmethod

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 231 additions & 7 deletions
@@ -30,9 +30,11 @@ class GPTQMarlinState(Enum):
 
 
 __all__ = [
-    "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
+    "CompressedTensorsMoEMethod",
+    "CompressedTensorsW8A8Fp8MoEMethod",
     "CompressedTensorsW8A8Fp8MoECutlassMethod",
-    "CompressedTensorsWNA16MoEMethod"
+    "CompressedTensorsWNA16MarlinMoEMethod",
+    "CompressedTensorsWNA16MoEMethod",
 ]
 
 
@@ -41,8 +43,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     @staticmethod
     def get_moe_method(
         quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
-        activation: str,
-        expert_map: Optional[torch.Tensor],
+        layer: torch.nn.Module,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -51,9 +52,20 @@ def get_moe_method(
             "input_activations")
 
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
-            return CompressedTensorsWNA16MoEMethod(quant_config)
+            # Prefer to use the non-marlin kernel when:
+            # 1. Many experts (MarlinMoE gives poor performance when >= 16)
+            # 2. Non-FP16 dtype (MarlinMoE only supports FP16)
+            # 3. Actorder is not dynamic (g_idx is unsupported)
+            # 4. Scales are grouped (channelwise is unsupported)
+            if ((layer.local_num_experts >= 16
+                 or layer.params_dtype != torch.float16)
+                    and weight_quant.actorder != "group"
+                    and weight_quant.strategy == "group"):
+                return CompressedTensorsWNA16MoEMethod(quant_config)
+            else:
+                return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
         elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
-              and activation == "silu" and expert_map is None):
+              and layer.activation == "silu" and layer.expert_map is None):
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
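The dispatch rule above is the core of this commit. As a hedged, standalone restatement (plain Python, not vLLM code; the function name and flat arguments are illustrative), the non-Marlin moe_wna16 path is chosen only when the expert count or dtype makes Marlin a poor fit and the scheme stays within what the fused kernel supports:

import torch
from typing import Optional

def prefers_moe_wna16_kernel(local_num_experts: int,
                             params_dtype: torch.dtype,
                             actorder: Optional[str],
                             strategy: str) -> bool:
    # 1./2. MarlinMoE is slow with >= 16 experts and only supports FP16,
    #       so either condition pushes toward the non-Marlin kernel.
    # 3.    Grouped act-order (g_idx) is unsupported by the non-Marlin kernel.
    # 4.    Only grouped scales are supported (channelwise is not).
    return ((local_num_experts >= 16 or params_dtype != torch.float16)
            and actorder != "group"
            and strategy == "group")

# Many bf16 experts -> non-Marlin kernel; a few fp16 experts -> Marlin fallback.
print(prefers_moe_wna16_kernel(60, torch.bfloat16, None, "group"))  # True
print(prefers_moe_wna16_kernel(8, torch.float16, None, "group"))    # False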
@@ -482,7 +494,7 @@ def apply(
         )
 
 
-class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
 
     def __init__(
             self,
@@ -823,3 +835,215 @@ def apply(
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
             is_k_full=self.is_k_full)
+
+
+class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+            self,
+            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        # TODO: @dsikka: refactor this to use schemes as other kernels
+        # are supported + check if the layer is being ignored.
+        config = self.quant_config.target_scheme_map["Linear"].get("weights")
+        self.num_bits = config.num_bits
+        self.packed_factor = 32 // config.num_bits
+        self.strategy = config.strategy
+        # channelwise is not supported by this kernel
+        assert config.strategy == "group"
+        self.group_size = config.group_size
+        # grouped actorder isn't supported by this kernel
+        assert config.actorder != "group"
+        assert config.symmetric, (
+            "Only symmetric quantization is supported for MoE")
+
+        if not (self.quant_config.quant_format
+                == CompressionFormat.pack_quantized.value
+                and self.num_bits in WNA16_SUPPORTED_BITS):
+            raise ValueError("For Fused MoE layers, only ",
+                             f"{CompressionFormat.pack_quantized.value} ",
+                             "is supported for the following bits: ",
+                             f"{WNA16_SUPPORTED_BITS}")
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        # Will transpose the loaded weight along the
+        # intermediate and hidden dim sizes. Will
+        # shard for TP along the transposed dims
+        extra_weight_attrs.update({
+            "is_transposed": True,
+            "quant_method": self.strategy
+        })
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size // self.packed_factor,
+            2 * intermediate_size_per_partition,
+            dtype=torch.int32),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight_packed", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            intermediate_size_per_partition // self.packed_factor,
+            hidden_size,
+            dtype=torch.int32),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight_packed", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w2_scales_size = intermediate_size_per_partition
+
+        if self.strategy == "channel":
+            num_groups_w2 = num_groups_w13 = 1
+            self.group_size = -1
+        else:
+            num_groups_w2 = w2_scales_size // self.group_size
+            num_groups_w13 = hidden_size // self.group_size
+
+        w13_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            num_groups_w13,
+            2 * intermediate_size_per_partition,
+            dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_scale)
+        set_weight_attrs(w13_scale, extra_weight_attrs)
+
+        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                 num_groups_w2,
+                                                 hidden_size,
+                                                 dtype=params_dtype),
+                                      requires_grad=False)
+        layer.register_parameter("w2_weight_scale", w2_scale)
+        set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": False})
+
+        w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                             requires_grad=False)
+        layer.register_parameter("w2_weight_shape", w2_weight_shape)
+        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
+        w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                              requires_grad=False)
+
+        layer.register_parameter("w13_weight_shape", w13_weight_shape)
+        set_weight_attrs(w13_weight_shape, extra_weight_attrs)
+
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices",
+                                 w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices",
+                                 w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+        layer.a13_scale = None
+        layer.a2_scale = None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Reconfigure packed weights and scales to match moe_wna16 format
+        layer.w13_weight_packed = torch.nn.Parameter(
+            layer.w13_weight_packed.transpose(1, 2).contiguous().view(
+                torch.uint8),
+            requires_grad=False)
+        layer.w2_weight_packed = torch.nn.Parameter(
+            layer.w2_weight_packed.transpose(1,
+                                             2).contiguous().view(torch.uint8),
+            requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(
+            layer.w13_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(
+            layer.w2_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe import fused_experts
+        assert activation == "silu", "Only SiLU activation is supported."
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        return fused_experts(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_int4_w4a16=self.num_bits == 4,
+            use_int8_w8a16=self.num_bits == 8,
+            global_num_experts=global_num_experts,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size])
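As a quick sanity check on the layout the new method sets up, here is a hedged, standalone sketch (plain PyTorch, not vLLM code; the sizes are made up for illustration). With num_bits = 4 the packed factor is 32 // 4 = 8 int4 values per int32 word, create_weights allocates w13_weight_packed as (num_experts, hidden_size // 8, 2 * intermediate) int32, and process_weights_after_loading transposes the last two dims and reinterprets the storage as uint8 (two int4 values per byte), matching the moe_wna16 format mentioned in the comment above.

import torch

num_experts, hidden_size, intermediate = 8, 64, 32
num_bits = 4
packed_factor = 32 // num_bits  # eight int4 values per int32 word

# Shape create_weights gives w13_weight_packed (the "is_transposed" layout).
w13 = torch.zeros(num_experts,
                  hidden_size // packed_factor,
                  2 * intermediate,
                  dtype=torch.int32)

# What process_weights_after_loading does: transpose, then view as uint8.
w13_kernel = w13.transpose(1, 2).contiguous().view(torch.uint8)

print(w13.shape)         # torch.Size([8, 8, 64])
print(w13_kernel.shape)  # torch.Size([8, 64, 32]) -> hidden_size // 2 bytes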
