@@ -6,7 +6,8 @@
 
 import torch
 from compressed_tensors import CompressionFormat
-from compressed_tensors.quantization import QuantizationStrategy
+from compressed_tensors.quantization import (ActivationOrdering,
+                                             QuantizationStrategy)
 
 import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
@@ -30,9 +31,11 @@ class GPTQMarlinState(Enum):
 
 
 __all__ = [
-    "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
+    "CompressedTensorsMoEMethod",
+    "CompressedTensorsW8A8Fp8MoEMethod",
     "CompressedTensorsW8A8Fp8MoECutlassMethod",
-    "CompressedTensorsWNA16MoEMethod"
+    "CompressedTensorsWNA16MarlinMoEMethod",
+    "CompressedTensorsWNA16MoEMethod",
 ]
 
 
@@ -41,8 +44,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     @staticmethod
     def get_moe_method(
         quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
-        activation: str,
-        expert_map: Optional[torch.Tensor],
+        layer: torch.nn.Module,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -51,9 +53,21 @@ def get_moe_method(
             "input_activations")
 
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
-            return CompressedTensorsWNA16MoEMethod(quant_config)
+            # Prefer to use the non-marlin kernel when:
+            # 1. Many experts (MarlinMoE gives poor performance when >= 16)
+            # 2. Non-FP16 dtype (MarlinMoE only supports FP16)
+            # 3. Actorder is not group/dynamic (g_idx is unsupported)
+            # 4. Scales are grouped (channelwise is unsupported)
+            if ((layer.local_num_experts >= 16
+                 or layer.params_dtype != torch.float16) and
+                    weight_quant.actorder not in (ActivationOrdering.GROUP,
+                                                  ActivationOrdering.DYNAMIC)
+                    and weight_quant.strategy in QuantizationStrategy.GROUP):
+                return CompressedTensorsWNA16MoEMethod(quant_config)
+            else:
+                return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
         elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
-              and activation == "silu" and expert_map is None):
+              and layer.activation == "silu" and layer.expert_map is None):
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
@@ -482,7 +496,7 @@ def apply(
         )
 
 
-class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
 
     def __init__(
             self,
@@ -823,3 +837,215 @@ def apply(
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
             is_k_full=self.is_k_full)
+
+
+class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+            self,
+            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        # TODO: @dsikka: refactor this to use schemes as other kernels
+        # are supported + check if the layer is being ignored.
+        config = self.quant_config.target_scheme_map["Linear"].get("weights")
+        self.num_bits = config.num_bits
+        self.packed_factor = 32 // config.num_bits
+        self.strategy = config.strategy
+        # channelwise is not supported by this kernel
+        assert config.strategy == "group"
+        self.group_size = config.group_size
+        # grouped actorder isn't supported by this kernel
+        assert config.actorder != "group"
+        assert config.symmetric, (
+            "Only symmetric quantization is supported for MoE")
+
+        if not (self.quant_config.quant_format
+                == CompressionFormat.pack_quantized.value
+                and self.num_bits in WNA16_SUPPORTED_BITS):
+            raise ValueError("For Fused MoE layers, only ",
+                             f"{CompressionFormat.pack_quantized.value} ",
+                             "is supported for the following bits: ",
+                             f"{WNA16_SUPPORTED_BITS}")
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        # Will transpose the loaded weight along the
+        # intermediate and hidden dim sizes. Will
+        # shard for TP along the transposed dims
+        extra_weight_attrs.update({
+            "is_transposed": True,
+            "quant_method": self.strategy
+        })
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size // self.packed_factor,
+            2 * intermediate_size_per_partition,
+            dtype=torch.int32),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight_packed", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            intermediate_size_per_partition // self.packed_factor,
+            hidden_size,
+            dtype=torch.int32),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight_packed", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w2_scales_size = intermediate_size_per_partition
+
+        if self.strategy == "channel":
+            num_groups_w2 = num_groups_w13 = 1
+            self.group_size = -1
+        else:
+            num_groups_w2 = w2_scales_size // self.group_size
+            num_groups_w13 = hidden_size // self.group_size
+
+        w13_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            num_groups_w13,
+            2 * intermediate_size_per_partition,
+            dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_scale)
+        set_weight_attrs(w13_scale, extra_weight_attrs)
+
+        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                 num_groups_w2,
+                                                 hidden_size,
+                                                 dtype=params_dtype),
+                                      requires_grad=False)
+        layer.register_parameter("w2_weight_scale", w2_scale)
+        set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": False})
+
+        w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                             requires_grad=False)
+        layer.register_parameter("w2_weight_shape", w2_weight_shape)
+        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
+        w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                              requires_grad=False)
+
+        layer.register_parameter("w13_weight_shape", w13_weight_shape)
+        set_weight_attrs(w13_weight_shape, extra_weight_attrs)
+
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices",
+                                 w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices",
+                                 w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+        layer.a13_scale = None
+        layer.a2_scale = None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Reconfigure packed weights and scales to match moe_wna16 format
+        layer.w13_weight_packed = torch.nn.Parameter(
+            layer.w13_weight_packed.transpose(1, 2).contiguous().view(
+                torch.uint8),
+            requires_grad=False)
+        layer.w2_weight_packed = torch.nn.Parameter(
+            layer.w2_weight_packed.transpose(1,
+                                             2).contiguous().view(torch.uint8),
+            requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(
+            layer.w13_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(
+            layer.w2_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe import fused_experts
+        assert activation == "silu", "Only SiLU activation is supported."
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        return fused_experts(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_int4_w4a16=self.num_bits == 4,
+            use_int8_w8a16=self.num_bits == 8,
+            global_num_experts=global_num_experts,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size])
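
The kernel-selection change in `get_moe_method` can be restated as a small standalone predicate. This is a minimal sketch, not vLLM code: the helper name `prefers_moe_wna16` and its plain-string `actorder`/`strategy` arguments are illustrative stand-ins for the `layer` and `weight_quant` attributes checked in the diff.

```python
from typing import Optional

import torch


# Hypothetical restatement of the heuristic added to get_moe_method above;
# this function is not part of vLLM.
def prefers_moe_wna16(local_num_experts: int, params_dtype: torch.dtype,
                      actorder: Optional[str], strategy: str) -> bool:
    """True when the non-Marlin WNA16 MoE kernel should be picked."""
    # 1 & 2: many experts, or a non-FP16 dtype, favor the non-Marlin kernel.
    many_experts_or_non_fp16 = (local_num_experts >= 16
                                or params_dtype != torch.float16)
    # 3: group/dynamic activation ordering needs g_idx, which it lacks.
    no_g_idx = actorder not in ("group", "dynamic")
    # 4: it only supports grouped scales (no channelwise quantization).
    grouped_scales = strategy == "group"
    return many_experts_or_non_fp16 and no_g_idx and grouped_scales


# Example: 60 local experts, bfloat16 weights, no activation ordering,
# group-wise scales -> CompressedTensorsWNA16MoEMethod is preferred.
assert prefers_moe_wna16(60, torch.bfloat16, None, "group")
```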
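For the new `CompressedTensorsWNA16MoEMethod`, the shape bookkeeping can also be checked in isolation. The sketch below uses assumed toy sizes (not values from the diff) and mirrors the `packed_factor` arithmetic in `create_weights` plus the transpose-and-view step in `process_weights_after_loading`.

```python
import torch

# Assumed example sizes for illustration only.
num_experts, hidden_size, intermediate = 8, 256, 512
num_bits = 4
packed_factor = 32 // num_bits  # eight int4 values per int32 word

# create_weights allocates the packed w13 weight along the hidden dim:
w13 = torch.empty(num_experts,
                  hidden_size // packed_factor,
                  2 * intermediate,
                  dtype=torch.int32)

# process_weights_after_loading transposes and reinterprets it as uint8
# so the moe_wna16 fused_experts path can consume it:
w13_uint8 = w13.transpose(1, 2).contiguous().view(torch.uint8)

# Each int32 word expands to four uint8 bytes along the last dimension.
assert w13_uint8.shape == (num_experts, 2 * intermediate,
                           hidden_size // packed_factor * 4)
```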