@@ -30,9 +30,11 @@ class GPTQMarlinState(Enum):
 
 
 __all__ = [
-    "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
+    "CompressedTensorsMoEMethod",
+    "CompressedTensorsW8A8Fp8MoEMethod",
     "CompressedTensorsW8A8Fp8MoECutlassMethod",
-    "CompressedTensorsWNA16MoEMethod"
+    "CompressedTensorsWNA16MarlinMoEMethod",
+    "CompressedTensorsWNA16MoEMethod",
 ]
 
 
@@ -41,8 +43,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     @staticmethod
     def get_moe_method(
         quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
-        activation: str,
-        expert_map: Optional[torch.Tensor],
+        layer: torch.nn.Module,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -51,9 +52,20 @@ def get_moe_method(
             "input_activations")
 
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
-            return CompressedTensorsWNA16MoEMethod(quant_config)
+            # Prefer to use the non-marlin kernel when:
+            # 1. Many experts (MarlinMoE gives poor performance when >= 16)
+            # 2. Non-FP16 dtype (MarlinMoE only supports FP16)
+            # 3. Actorder is not group (g_idx is unsupported)
+            # 4. Scales are grouped (channelwise is unsupported)
+            if ((layer.local_num_experts >= 16
+                 or layer.params_dtype != torch.float16)
+                    and weight_quant.actorder != "group"
+                    and weight_quant.strategy == "group"):
+                return CompressedTensorsWNA16MoEMethod(quant_config)
+            else:
+                return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
         elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
-              and activation == "silu" and expert_map is None):
+              and layer.activation == "silu" and layer.expert_map is None):
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
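
For reference, a minimal sketch of how the dispatch above plays out; the layer and scheme values are hypothetical, not taken from this diff. A layer with many local experts or a non-FP16 dtype, group-wise scales, and no grouped actorder goes to the new CompressedTensorsWNA16MoEMethod, otherwise the Marlin path is kept.

from types import SimpleNamespace

import torch


def prefers_moe_wna16(layer, weight_quant) -> bool:
    # Same predicate as in get_moe_method above.
    return ((layer.local_num_experts >= 16
             or layer.params_dtype != torch.float16)
            and weight_quant.actorder != "group"
            and weight_quant.strategy == "group")


wq = SimpleNamespace(actorder=None, strategy="group")
many_experts = SimpleNamespace(local_num_experts=64, params_dtype=torch.bfloat16)
few_experts = SimpleNamespace(local_num_experts=8, params_dtype=torch.float16)
print(prefers_moe_wna16(many_experts, wq))  # True  -> CompressedTensorsWNA16MoEMethod
print(prefers_moe_wna16(few_experts, wq))   # False -> CompressedTensorsWNA16MarlinMoEMethod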
@@ -482,7 +494,7 @@ def apply(
         )
 
 
-class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
 
     def __init__(
         self,
@@ -823,3 +835,215 @@ def apply(
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
             is_k_full=self.is_k_full)
+
+
+class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+            self,
+            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        # TODO: @dsikka: refactor this to use schemes as other kernels
+        # are supported + check if the layer is being ignored.
+        config = self.quant_config.target_scheme_map["Linear"].get("weights")
+        self.num_bits = config.num_bits
+        self.packed_factor = 32 // config.num_bits
+        self.strategy = config.strategy
+        # channelwise is not supported by this kernel
+        assert config.strategy == "group"
+        self.group_size = config.group_size
+        # grouped actorder isn't supported by this kernel
+        assert config.actorder != "group"
+        assert config.symmetric, (
+            "Only symmetric quantization is supported for MoE")
+
+        if not (self.quant_config.quant_format
+                == CompressionFormat.pack_quantized.value
+                and self.num_bits in WNA16_SUPPORTED_BITS):
+            raise ValueError("For Fused MoE layers, only ",
+                             f"{CompressionFormat.pack_quantized.value} ",
+                             "is supported for the following bits: ",
+                             f"{WNA16_SUPPORTED_BITS}")
+
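
A quick arithmetic note on packed_factor (a worked example, not part of the diff): it is the number of quantized weights that fit in one 32-bit word, which is why the packed weight shapes registered next divide a dimension by it.

# 4-bit weights: 32 // 4 = 8 values per int32 word,
# so a hidden_size of 4096 packs into 4096 // 8 = 512 int32 rows.
assert 32 // 4 == 8 and 4096 // (32 // 4) == 512
# 8-bit weights: 32 // 8 = 4 values per int32 word.
assert 32 // 8 == 4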
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        # Will transpose the loaded weight along the
+        # intermediate and hidden dim sizes. Will
+        # shard for TP along the transposed dims
+        extra_weight_attrs.update({
+            "is_transposed": True,
+            "quant_method": self.strategy
+        })
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size // self.packed_factor,
+            2 * intermediate_size_per_partition,
+            dtype=torch.int32),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight_packed", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            intermediate_size_per_partition // self.packed_factor,
+            hidden_size,
+            dtype=torch.int32),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight_packed", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w2_scales_size = intermediate_size_per_partition
+
+        if self.strategy == "channel":
+            num_groups_w2 = num_groups_w13 = 1
+            self.group_size = -1
+        else:
+            num_groups_w2 = w2_scales_size // self.group_size
+            num_groups_w13 = hidden_size // self.group_size
+
+        w13_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            num_groups_w13,
+            2 * intermediate_size_per_partition,
+            dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w13_weight_scale", w13_scale)
+        set_weight_attrs(w13_scale, extra_weight_attrs)
+
+        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                 num_groups_w2,
+                                                 hidden_size,
+                                                 dtype=params_dtype),
+                                      requires_grad=False)
+        layer.register_parameter("w2_weight_scale", w2_scale)
+        set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": False})
+
+        w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                             requires_grad=False)
+        layer.register_parameter("w2_weight_shape", w2_weight_shape)
+        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
+        w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
+                                              requires_grad=False)
+
+        layer.register_parameter("w13_weight_shape", w13_weight_shape)
+        set_weight_attrs(w13_weight_shape, extra_weight_attrs)
+
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices",
+                                 w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices",
+                                 w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+        layer.a13_scale = None
+        layer.a2_scale = None
+
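
To make the layout concrete, here is a small sketch of the shapes create_weights registers for a 4-bit, group-size-128 scheme; the expert count and hidden/intermediate sizes are made-up example values, not taken from this diff.

num_experts, hidden, inter = 8, 4096, 14336  # hypothetical sizes
packed_factor, group_size = 32 // 4, 128
# w13_weight_packed: (E, hidden // packed_factor, 2 * inter)
print((num_experts, hidden // packed_factor, 2 * inter))  # (8, 512, 28672)
# w2_weight_packed: (E, inter // packed_factor, hidden)
print((num_experts, inter // packed_factor, hidden))       # (8, 1792, 4096)
# w13_weight_scale: (E, hidden // group_size, 2 * inter)
print((num_experts, hidden // group_size, 2 * inter))      # (8, 32, 28672)
# w2_weight_scale: (E, inter // group_size, hidden)
print((num_experts, inter // group_size, hidden))           # (8, 112, 4096)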
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Reconfigure packed weights and scales to match moe_wna16 format
+        layer.w13_weight_packed = torch.nn.Parameter(
+            layer.w13_weight_packed.transpose(1, 2).contiguous().view(
+                torch.uint8),
+            requires_grad=False)
+        layer.w2_weight_packed = torch.nn.Parameter(
+            layer.w2_weight_packed.transpose(1, 2).contiguous().view(
+                torch.uint8),
+            requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(
+            layer.w13_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(
+            layer.w2_weight_scale.transpose(1, 2).contiguous(),
+            requires_grad=False)
+
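
A minimal sketch of what the reconfiguration above does to one packed tensor (example sizes are hypothetical): the transpose moves the packed dimension to the end, and reinterpreting the int32 storage as uint8 multiplies that last dimension by four.

import torch

E, K, N, packed_factor = 2, 512, 1024, 8  # hypothetical 4-bit sizes
w = torch.zeros(E, K // packed_factor, N, dtype=torch.int32)
w = w.transpose(1, 2).contiguous().view(torch.uint8)
print(w.shape)  # torch.Size([2, 1024, 256]) -> (E, N, K // 2): two nibbles per byte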
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe import fused_experts
+        assert activation == "silu", "Only SiLU activation is supported."
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        return fused_experts(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_int4_w4a16=self.num_bits == 4,
+            use_int8_w8a16=self.num_bits == 8,
+            global_num_experts=global_num_experts,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size])
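
For intuition about what the fused kernel consumes here, a reference sketch of symmetric W4A16 group dequantization for a single packed column; this is editorial, not the actual Triton kernel, and the nibble order and zero-point of 8 are assumptions for illustration.

import torch

def dequant_int4_column(packed: torch.Tensor, scale: torch.Tensor,
                        group_size: int) -> torch.Tensor:
    # packed: (K // 2,) uint8, two 4-bit weights per byte
    # scale:  (K // group_size,) per-group scales in the activation dtype
    lo = (packed & 0x0F).to(torch.int16)
    hi = (packed >> 4).to(torch.int16)
    q = torch.stack([lo, hi], dim=-1).reshape(-1)  # (K,) unsigned 4-bit values
    q = q - 8                                      # symmetric: zero-point of 8
    return q.to(scale.dtype) * scale.repeat_interleave(group_size)

The block_shape=[0, self.group_size] argument is how the group size reaches the kernel, and the zero-points stay None because the scheme asserts symmetric quantization.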