
Commit 0b8eaec

Re-fix Quark API
1 parent 29241ca commit 0b8eaec

2 files changed: 14 additions, 1 deletion


vllm/model_executor/layers/quantization/quark/quark.py

Lines changed: 13 additions & 0 deletions
@@ -158,6 +158,19 @@ def _check_scheme_supported(self,
         else:
             return False
 
+    def is_fp8_w8a8(self) -> bool:
+        # Returns True if all quantized layers in model are fp8 w8a8
+        global_quant_config = cast(
+            Dict[str, Any], self.quant_config.get("global_quant_config"))
+        layer_quant_configs = cast(Dict[str, Any],
+                                   self.quant_config.get("layer_quant_config"))
+        for config in (global_quant_config, *layer_quant_configs.values()):
+            weight_config = cast(Dict[str, Any], config.get("weight"))
+            input_config = cast(Dict[str, Any], config.get("input_tensors"))
+            if not self._is_fp8_w8a8(weight_config, input_config):
+                return False
+        return True
+
     def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
                      input_quant: Optional[Dict[str, Any]]) -> bool:
         # Confirm weights and input quantized.
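For context, here is a standalone sketch of the check the new public method performs, outside the QuarkConfig class. The key names ("global_quant_config", "layer_quant_config", "weight", "input_tensors") come from the diff above; the example config dict and the "fp8_e4m3" dtype check are illustrative assumptions, not the actual Quark config schema or the body of the real _is_fp8_w8a8.

    from typing import Any, Dict, Optional


    def _is_fp8_w8a8(weight_quant: Optional[Dict[str, Any]],
                     input_quant: Optional[Dict[str, Any]]) -> bool:
        # Illustrative stand-in for QuarkConfig._is_fp8_w8a8: assume both the
        # weight and activation configs must exist and use an fp8 dtype.
        if weight_quant is None or input_quant is None:
            return False
        return (weight_quant.get("dtype") == "fp8_e4m3"
                and input_quant.get("dtype") == "fp8_e4m3")


    def is_fp8_w8a8(quant_config: Dict[str, Any]) -> bool:
        # Mirrors the new method: the global config plus every per-layer
        # override must be fp8 weight/activation (w8a8) quantized.
        global_quant_config = quant_config.get("global_quant_config")
        layer_quant_configs = quant_config.get("layer_quant_config") or {}
        for config in (global_quant_config, *layer_quant_configs.values()):
            weight_config = config.get("weight")
            input_config = config.get("input_tensors")
            if not _is_fp8_w8a8(weight_config, input_config):
                return False
        return True


    # Hypothetical config shaped like the keys used in the diff.
    example_config = {
        "global_quant_config": {
            "weight": {"dtype": "fp8_e4m3"},
            "input_tensors": {"dtype": "fp8_e4m3"},
        },
        "layer_quant_config": {
            "model.layers.0.mlp.down_proj": {
                "weight": {"dtype": "fp8_e4m3"},
                "input_tensors": {"dtype": "fp8_e4m3"},
            },
        },
    }

    print(is_fp8_w8a8(example_config))  # True for this all-fp8 example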

vllm/model_executor/models/llama.py

Lines changed: 1 addition & 1 deletion
@@ -250,7 +250,7 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.use_fp8 = (isinstance(quant_config, Fp8Config) or
                         (isinstance(quant_config, QuarkConfig)
-                         and quant_config._is_fp8_w8a8())
+                         and quant_config.is_fp8_w8a8())
                         if current_platform.is_fp8_fnuz() else False)
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
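The call-site fix follows from the signatures above: _is_fp8_w8a8 requires explicit weight and input-tensor configs, so the previous zero-argument call could not work, while the new is_fp8_w8a8() takes no arguments and aggregates the check over the global and per-layer configs itself. A minimal sketch of the guarded flag, using hypothetical stand-ins for vLLM's Fp8Config, QuarkConfig, and platform check (only the method name comes from the diff):

    class Fp8Config:            # stand-in for vLLM's Fp8Config
        pass


    class QuarkConfig:          # stand-in for vLLM's QuarkConfig
        def is_fp8_w8a8(self) -> bool:
            return True         # assume an all-fp8 checkpoint for the example


    def compute_use_fp8(quant_config, platform_is_fp8_fnuz: bool) -> bool:
        # Mirrors the updated expression: the fp8 path is taken only on
        # fnuz-fp8 platforms, for either a native Fp8Config or a Quark
        # config whose quantized layers are all fp8 w8a8.
        return ((isinstance(quant_config, Fp8Config) or
                 (isinstance(quant_config, QuarkConfig)
                  and quant_config.is_fp8_w8a8()))
                if platform_is_fp8_fnuz else False)


    print(compute_use_fp8(QuarkConfig(), platform_is_fp8_fnuz=True))   # True
    print(compute_use_fp8(QuarkConfig(), platform_is_fp8_fnuz=False))  # False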
