from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    dynamic_quant,
    dequant_block_fp8_weight_naive,
-   apply_block_fp8_linear_hpu_dynamic)
+   apply_block_fp8_linear_hpu_dynamic,
+   apply_block_fp8_linear_hpu_dequant)

if current_platform.is_hpu():
    import habana_frameworks.torch as htorch
@@ -58,6 +59,7 @@ def __init__(
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
    ) -> None:
+       self.enable_runtime_dequant = os.environ.get("VLLM_ENABLE_RUNTIME_DEQUANT", "0") in ["1", "true"]
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
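The runtime-dequant path added by this change is opted into through the `VLLM_ENABLE_RUNTIME_DEQUANT` environment variable parsed above ("1" or "true" enables it). A minimal usage sketch, assuming the variable is set before the engine builds its `Fp8Config`; the model name is purely illustrative:

    import os

    # Opt in to runtime block-FP8 dequantization on HPU. Any value other than
    # "1" or "true" keeps the default requantize-at-load behaviour.
    os.environ["VLLM_ENABLE_RUNTIME_DEQUANT"] = "1"

    from vllm import LLM

    llm = LLM(model="my-org/my-block-fp8-model")  # hypothetical FP8 block-quantized checkpoint
    print(llm.generate("Hello")[0].outputs[0].text)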
@@ -282,17 +284,24 @@ def process_weights_after_loading(self, layer: Module) -> None:
                    layer.weight.data,
                    layer.weight_scale_inv.data,
                    self.quant_config.weight_block_size)
-               weight, weight_scale_inv = dynamic_quant(dequant_block_fp8_weight_naive(
-                   weight,
-                   layer.weight_scale_inv.data,
-                   self.quant_config.weight_block_size,
-                   original_M=orig_M,
-                   original_N=orig_N,
-                   do_unpad=True))
-               weight_scale_inv = weight_scale_inv.squeeze(-1)
-               layer.weight.data.copy_(weight)
-               layer.weight_scale_inv = Parameter(weight_scale_inv,
-                                                  requires_grad=False)
+               if self.quant_config.enable_runtime_dequant:
+                   layer.weight = torch.nn.Parameter(weight, requires_grad=False)
+                   orig_M = torch.nn.Parameter(torch.tensor(orig_M, dtype=torch.int32), requires_grad=False)
+                   orig_N = torch.nn.Parameter(torch.tensor(orig_N, dtype=torch.int32), requires_grad=False)
+                   layer.register_parameter("orig_M", orig_M)
+                   layer.register_parameter("orig_N", orig_N)
+               else:
+                   weight, weight_scale_inv = dynamic_quant(dequant_block_fp8_weight_naive(
+                       weight,
+                       layer.weight_scale_inv.data,
+                       self.quant_config.weight_block_size,
+                       original_M=orig_M,
+                       original_N=orig_N,
+                       do_unpad=True))
+                   weight_scale_inv = weight_scale_inv.squeeze(-1)
+                   layer.weight.data.copy_(weight)
+                   layer.weight_scale_inv = Parameter(weight_scale_inv,
+                                                      requires_grad=False)
                return
            if current_platform.is_rocm():
                weight, weight_scale_inv, _ = \
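With the flag enabled, the padded block-FP8 weight is stored untouched and the original (unpadded) dimensions are registered on the layer as `orig_M`/`orig_N`, deferring dequantization to `apply`; otherwise the weight is dequantized once at load time and requantized per-channel by `dynamic_quant`. For intuition, here is a rough sketch of the block dequantization that `dequant_block_fp8_weight_naive` appears to perform for a 2-D weight, where each `[block_n, block_k]` tile is scaled by its `weight_scale_inv` entry. The function name and arguments come from this diff; the body is an assumption, and padding/unpadding is omitted:

    import torch

    def dequant_block_fp8_sketch(weight_fp8: torch.Tensor,
                                 scale_inv: torch.Tensor,
                                 block_size,
                                 dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
        # Broadcast each per-block scale over its [block_n, block_k] tile, then multiply.
        block_n, block_k = block_size
        scale = scale_inv.repeat_interleave(block_n, dim=0)
        scale = scale.repeat_interleave(block_k, dim=1)
        scale = scale[:weight_fp8.shape[0], :weight_fp8.shape[1]]  # trim to the weight's padded shape
        return weight_fp8.to(dtype) * scale.to(dtype)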
@@ -404,13 +413,26 @@ def apply(self,
        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            if current_platform.is_hpu():
-               return apply_block_fp8_linear_hpu_dynamic(
-                   input=x,
-                   weight=layer.weight,
-                   weight_scale=layer.weight_scale_inv,
-                   input_scale=layer.input_scale,
-                   bias=bias,
-               )
+               if self.quant_config.enable_runtime_dequant:
+                   return apply_block_fp8_linear_hpu_dequant(
+                       input=x,
+                       weight=layer.weight,
+                       block_size=self.quant_config.weight_block_size,
+                       weight_scale=layer.weight_scale_inv,
+                       input_scale=layer.input_scale,
+                       bias=bias,
+                       original_M=layer.orig_M,
+                       original_N=layer.orig_N,
+                       do_unpad=True,
+                   )
+               else:
+                   return apply_block_fp8_linear_hpu_dynamic(
+                       input=x,
+                       weight=layer.weight,
+                       weight_scale=layer.weight_scale_inv,
+                       input_scale=layer.input_scale,
+                       bias=bias,
+                   )
            return apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
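In the runtime-dequant branch, `apply` hands the still-quantized weight, its block scales, the block size, and the stored `orig_M`/`orig_N` to `apply_block_fp8_linear_hpu_dequant`. A sketch of what that helper is presumed to do, reusing `dequant_block_fp8_sketch` from the earlier sketch; only the helper's name and keyword arguments are taken from this diff, the body is an assumption:

    import torch
    import torch.nn.functional as F

    def apply_block_fp8_linear_hpu_dequant_sketch(input, weight, block_size,
                                                  weight_scale, input_scale=None,
                                                  bias=None, original_M=None,
                                                  original_N=None, do_unpad=True):
        # Dequantize the block-FP8 weight to the activation dtype on the fly,
        # drop the block-alignment padding, and fall back to a plain linear.
        # input_scale is unused in this simplified sketch.
        w = dequant_block_fp8_sketch(weight, weight_scale, block_size, dtype=input.dtype)
        if do_unpad and original_M is not None and original_N is not None:
            w = w[:int(original_M), :int(original_N)]
        return F.linear(input, w, bias)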
@@ -615,6 +637,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
        # TODO (rob): refactor block quant into separate class.
        if self.block_quant:
            if current_platform.is_hpu():
+               if self.quant_config.enable_runtime_dequant:
+                   return
                w13_weight, w13_weight_scale_inv = dynamic_quant(dequant_block_fp8_weight_naive(
                    layer.w13_weight.data,
                    layer.w13_weight_scale_inv.data,
@@ -946,6 +970,7 @@ def do_dynamic_moe_with_static_scaling(x, topk_ids, topk_weights, w13_weight_fp8
                    activation="silu",
                    experts_min=min_expert + ep_shift,
                    experts_max=max_expert - 1 + ep_shift)
+               htorch.core.mark_step()
                if i == 0:
                    final_hidden_states = current_hidden_states
                else:
@@ -976,6 +1001,40 @@ def do_dynamic_moe_with_dynamic_scaling(x, topk_ids, topk_weights, w13_weight_fp
                    activation="silu",
                    experts_min=min_expert + ep_shift,
                    experts_max=max_expert - 1 + ep_shift)
+               htorch.core.mark_step()
+               if i == 0:
+                   final_hidden_states = current_hidden_states
+               else:
+                   final_hidden_states.add_(current_hidden_states)
+           return final_hidden_states
+
+       def do_dynamic_moe_with_dequant(x, topk_ids, topk_weights, w13_weight_fp8, w2_weight_fp8, moe_n_slice, n_expert_slice, w13_weight_scale_inv_fp8=None, w2_weight_scale_inv_fp8=None):
+           w13_weight = dequant_block_fp8_weight_naive(w13_weight_fp8,
+                                                       w13_weight_scale_inv_fp8,
+                                                       block_size=self.quant_config.weight_block_size,
+                                                       dtype=x.dtype)
+           w2_weight = dequant_block_fp8_weight_naive(w2_weight_fp8,
+                                                      w2_weight_scale_inv_fp8,
+                                                      block_size=self.quant_config.weight_block_size,
+                                                      dtype=x.dtype)
+           for i in range(moe_n_slice):
+               min_expert = i * n_expert_slice
+               max_expert = (i + 1) * n_expert_slice
+
+               w13_list_slice = [w13_weight[j, ...] for j in range(min_expert, max_expert)]
+               w2_list_slice = [w2_weight[j, ...] for j in range(min_expert, max_expert)]
+
+               current_hidden_states = torch.ops.hpu.mixture_of_experts(
+                   hidden_states=x,
+                   expert_routing_table=topk_ids.to(torch.int64),
+                   router_weights=topk_weights.to(x.dtype),
+                   w12=w13_list_slice,
+                   w3=w2_list_slice,
+                   permuted_weights=True,
+                   activation="silu",
+                   experts_min=min_expert + ep_shift,
+                   experts_max=max_expert - 1 + ep_shift)
+               htorch.core.mark_step()
                if i == 0:
                    final_hidden_states = current_hidden_states
                else:
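The dequant MoE variant materializes high-precision copies of `w13` and `w2` once per call and then reuses the same sliced `torch.ops.hpu.mixture_of_experts` loop as the other variants, with `htorch.core.mark_step()` closing the lazy-mode graph after each slice. The slice bounds work out as below; the expert counts are hypothetical and only illustrate the arithmetic:

    num_experts, moe_n_slice, ep_shift = 256, 8, 0   # illustrative values, not from the diff
    n_expert_slice = num_experts // moe_n_slice      # 32 experts per mixture_of_experts call
    for i in range(moe_n_slice):
        min_expert = i * n_expert_slice              # 0, 32, 64, ...
        max_expert = (i + 1) * n_expert_slice        # 32, 64, 96, ...
        # experts_min/experts_max are inclusive bounds, hence the "- 1".
        print(min_expert + ep_shift, max_expert - 1 + ep_shift)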
@@ -1003,7 +1062,9 @@ def do_dynamic_moe_with_dynamic_scaling(x, topk_ids, topk_weights, w13_weight_fp
        moe_n_slice = self.moe_n_slice

        if self.quant_config.activation_scheme == "dynamic":
-           if not use_static_moe and self.enable_dmoe_dynamic_scale:
+           if self.quant_config.enable_runtime_dequant:
+               final_hidden_states = do_dynamic_moe_with_dequant(x, topk_ids, topk_weights, w13_weight_fp8, w2_weight_fp8, moe_n_slice, n_expert_slice, w13_weight_scale_inv_fp8, w2_weight_scale_inv_fp8)
+           elif not use_static_moe and self.enable_dmoe_dynamic_scale:
                final_hidden_states = do_dynamic_moe_with_dynamic_scaling(x, topk_ids, topk_weights, w13_weight_fp8, w2_weight_fp8, moe_n_slice, n_expert_slice, w13_weight_scale_inv_fp8, w2_weight_scale_inv_fp8)
            else:
                final_hidden_states = do_static_moe_with_dynamic_scaling(x, topk_ids, topk_weights, w13_weight_fp8, w2_weight_fp8, actual_total_experts, actual_num_experts, w13_weight_scale_inv_fp8, w2_weight_scale_inv_fp8)