@@ -247,6 +247,15 @@ def create_weights(
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm():
+                weight, weight_scale, _ = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.weight,
+                        weight_scale=layer.weight_scale_inv,
+                        input_scale=layer.input_scale)
+                layer.weight = Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = Parameter(weight_scale,
+                                                   requires_grad=False)
             return
         layer.weight = torch.nn.Parameter(layer.weight.data,
                                           requires_grad=False)
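Both hunks route through the same normalize_e4m3fn_to_e4m3fnuz helper, whose body lies outside this diff. A minimal sketch of what such a helper has to do, assuming the standard encodings (e4m3fnuz uses exponent bias 8 rather than e4m3fn's 7, and the 0x80 bit pattern encodes NaN rather than -0.0):

from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz(
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # Sketch of the helper used above, not the verbatim implementation.
    assert weight.dtype == torch.float8_e4m3fn
    # 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz, so clear it to +0.0
    # before reinterpreting the raw bits as e4m3fnuz.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # A given bit pattern decodes to half the value under e4m3fnuz
    # (bias 8 vs. 7), so double the scales to keep the dequantized
    # values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale

Doubling the scale rather than rescaling the weights keeps the conversion lossless: the bits are reinterpreted, not requantized.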
@@ -495,6 +504,30 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm():
+                w13_weight, w13_weight_scale_inv, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale_inv,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale_inv, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale_inv,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale_inv, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale_inv, requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)
             return
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
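A quick round-trip check of the scale doubling (hypothetical test code, not part of the patch), using the sketch above:

import torch

# Dequantizing with the fnuz bits and the doubled scale should
# reproduce the fn dequantization, since multiplying by a power of
# two is lossless.
w_fn = torch.randn(4, 4).to(torch.float8_e4m3fn)
scale = torch.tensor(0.5)

w_fnuz, new_scale, _ = normalize_e4m3fn_to_e4m3fnuz(w_fn, scale)

ref = w_fn.to(torch.float32) * scale
out = w_fnuz.to(torch.float32) * new_scale
torch.testing.assert_close(ref, out)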