@@ -323,13 +323,18 @@ def create_weights(
323
323
params_dtype : torch .dtype ,
324
324
** extra_weight_attrs ,
325
325
):
326
- # Currently assuming is_k_full is always True
327
- # (input size per partition is the same as full input size)
328
- # Supports only sym for now (no zp)
326
+ intermediate_size_full = extra_weight_attrs .pop (
327
+ "intermediate_size_full" )
328
+
329
+ self .is_k_full = (not self .quant_config .desc_act ) or (
330
+ intermediate_size_per_partition == intermediate_size_full )
331
+
329
332
if self .quant_config .group_size != - 1 :
330
333
scales_size13 = hidden_size // self .quant_config .group_size
331
- scales_size2 = (intermediate_size_per_partition //
332
- self .quant_config .group_size )
334
+ w2_scales_size = (intermediate_size_full
335
+ if self .quant_config .desc_act else
336
+ intermediate_size_per_partition )
337
+ scales_size2 = (w2_scales_size // self .quant_config .group_size )
333
338
strategy = FusedMoeWeightScaleSupported .GROUP .value
334
339
else :
335
340
scales_size13 = 1
@@ -385,6 +390,9 @@ def create_weights(
385
390
)
386
391
layer .register_parameter ("w2_scales" , w2_scales )
387
392
set_weight_attrs (w2_scales , extra_weight_attrs )
393
+ # dont shard the w2 scales when running act order
394
+ set_weight_attrs (w2_scales ,
395
+ {"load_full_w2" : self .quant_config .desc_act })
388
396
# up_proj scales
389
397
w13_qzeros = torch .nn .Parameter (
390
398
torch .empty (num_experts ,
@@ -406,6 +414,9 @@ def create_weights(
406
414
)
407
415
layer .register_parameter ("w2_qzeros" , w2_qzeros )
408
416
set_weight_attrs (w2_qzeros , extra_weight_attrs )
417
+ # dont shard the w2 scales when running act order
418
+ set_weight_attrs (w2_qzeros ,
419
+ {"load_full_w2" : self .quant_config .desc_act })
409
420
w13_g_idx = torch .nn .Parameter (
410
421
torch .empty (
411
422
num_experts ,
@@ -575,4 +586,4 @@ def apply(
575
586
sort_indices1 = layer .w13_g_idx_sort_indices ,
576
587
sort_indices2 = layer .w2_g_idx_sort_indices ,
577
588
num_bits = self .quant_config .quant_type .size_bits ,
578
- ).to (orig_dtype )
589
+ is_k_full = self . is_k_full ).to (orig_dtype )
0 commit comments