File tree: 2 files changed (+14 −1 lines changed)

layers/quantization/quark
```diff
@@ -158,6 +158,19 @@ def _check_scheme_supported(self,
         else:
             return False
 
+    def is_fp8_w8a8(self) -> bool:
+        # Returns True if all quantized layers in model are fp8 w8a8
+        global_quant_config = cast(
+            Dict[str, Any], self.quant_config.get("global_quant_config"))
+        layer_quant_configs = cast(Dict[str, Any],
+                                   self.quant_config.get("layer_quant_config"))
+        for config in (global_quant_config, *layer_quant_configs.values()):
+            weight_config = cast(Dict[str, Any], config.get("weight"))
+            input_config = cast(Dict[str, Any], config.get("input_tensors"))
+            if not self._is_fp8_w8a8(weight_config, input_config):
+                return False
+        return True
+
     def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
                      input_quant: Optional[Dict[str, Any]]) -> bool:
         # Confirm weights and input quantized.
```
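The new public method aggregates the existing per-layer `_is_fp8_w8a8` check over the whole model: the global config and every per-layer override must all be fp8 w8a8, so a single non-fp8 entry fails the check. Below is a standalone sketch of that logic against a hand-built config dict; the `dtype` field and its `"fp8_e4m3"` value are assumptions for illustration, not taken from Quark's actual schema.

```python
from typing import Any, Dict, Optional


def _looks_fp8(cfg: Optional[Dict[str, Any]]) -> bool:
    # Stand-in for the per-tensor-config test in QuarkConfig._is_fp8_w8a8;
    # the "dtype" key is an assumed field name for this sketch.
    return cfg is not None and cfg.get("dtype") == "fp8_e4m3"


def model_is_fp8_w8a8(quant_config: Dict[str, Any]) -> bool:
    # Mirrors the added is_fp8_w8a8(): walk the global config plus every
    # per-layer override; one non-fp8 entry fails the whole model.
    configs = (quant_config["global_quant_config"],
               *quant_config.get("layer_quant_config", {}).values())
    return all(
        _looks_fp8(c.get("weight")) and _looks_fp8(c.get("input_tensors"))
        for c in configs)


fp8 = {"dtype": "fp8_e4m3"}
cfg = {
    "global_quant_config": {"weight": fp8, "input_tensors": fp8},
    "layer_quant_config": {
        "model.layers.0.mlp": {"weight": fp8, "input_tensors": fp8},
    },
}
print(model_is_fp8_w8a8(cfg))  # True; an int8 layer entry would flip it to False
```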
```diff
@@ -250,7 +250,7 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.use_fp8 = (isinstance(quant_config, Fp8Config) or
                         (isinstance(quant_config, QuarkConfig)
-                         and quant_config._is_fp8_w8a8())
+                         and quant_config.is_fp8_w8a8())
                         if current_platform.is_fp8_fnuz() else False)
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
```
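Note that `_is_fp8_w8a8` takes per-layer weight and input dicts, so the previous zero-argument call at this site could not have supplied them; the new zero-argument public method answers the model-wide question directly. A self-contained sketch of the call-site predicate follows; `Fp8Config`, `QuarkConfig`, and `current_platform` are the vLLM names used in the diff, and the stub classes here exist only to make the snippet runnable.

```python
class Fp8Config:
    # Stub standing in for vLLM's Fp8Config.
    ...


class QuarkConfig:
    # Stub standing in for vLLM's QuarkConfig.
    def is_fp8_w8a8(self) -> bool:
        return True  # real logic shown in the previous hunk


class _Platform:
    def is_fp8_fnuz(self) -> bool:
        return True  # e.g. ROCm devices using the float8_e4m3fnuz format


current_platform = _Platform()
quant_config = QuarkConfig()

# Same shape as the expression in the diff: the fp8 path is taken only on
# fp8-fnuz platforms, and only when the config is Fp8Config or an all-fp8
# QuarkConfig.
use_fp8 = ((isinstance(quant_config, Fp8Config)
            or (isinstance(quant_config, QuarkConfig)
                and quant_config.is_fp8_w8a8()))
           if current_platform.is_fp8_fnuz() else False)
print(use_fp8)  # True with these stubs
```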