
Commit 18b4fbe

rasmith authored and lk-chen committed
[AMD][FP8][BugFix] Remove V1 check in arg_utils.py for FP8 since it is not necessary (vllm-project#17215)
Signed-off-by: Randall Smith <[email protected]>
1 parent 1918bea commit 18b4fbe

File tree: 2 files changed (+0, -29 lines)

vllm/engine/arg_utils.py

Lines changed: 0 additions & 17 deletions
@@ -1368,23 +1368,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=False)
             return False
 
-        if current_platform.is_rocm():
-            from vllm.model_executor.layers.quantization.fp8 import Fp8Config
-            load_config = self.create_load_config()
-            quantization_config = VllmConfig.get_quantization_config(
-                model_config, load_config)
-            if isinstance(quantization_config, Fp8Config):
-                _raise_or_fallback(feature_name="fp8 for ROCm",
-                                   recommend_to_remove=False)
-                return False
-            from vllm.model_executor.layers.quantization.quark.quark import (
-                QuarkConfig)
-
-            if isinstance(quantization_config, QuarkConfig
-                          ) and quantization_config.has_fp8_layer_weights():
-                _raise_or_fallback(feature_name="Quark fp8 for ROCm",
-                                   recommend_to_remove=False)
-
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
             fp8_attention = self.kv_cache_dtype.startswith("fp8")
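The removed branch is an instance of the V1-engine "oracle" pattern in _is_v1_supported_oracle: detect a feature combination the V1 engine cannot handle, call _raise_or_fallback, and return False so the engine falls back to V0. A minimal, self-contained sketch of that pattern follows; raise_or_fallback and use_v1_requested are illustrative stand-ins rather than vLLM's actual API, and the FP8-on-ROCm check shown is exactly the one this commit deletes as no longer necessary:

import logging

logger = logging.getLogger(__name__)


def raise_or_fallback(feature_name: str, use_v1_requested: bool) -> None:
    # Illustrative stand-in for vLLM's _raise_or_fallback: raise when V1
    # was explicitly requested, otherwise warn and allow a V0 fallback.
    if use_v1_requested:
        raise NotImplementedError(
            f"{feature_name} is not supported on the V1 engine")
    logger.warning("%s is not supported on V1, falling back to V0",
                   feature_name)


def is_v1_supported(quant_method: str, on_rocm: bool,
                    use_v1_requested: bool = False) -> bool:
    # Pre-commit behavior: FP8 quantization on ROCm forced a V0 fallback.
    # The commit deletes this check because the restriction no longer holds.
    if on_rocm and quant_method == "fp8":
        raise_or_fallback("fp8 for ROCm", use_v1_requested)
        return False
    return True


# FP8 on ROCm fell back to V0; FP8 elsewhere stayed on V1.
assert is_v1_supported("fp8", on_rocm=True) is False
assert is_v1_supported("fp8", on_rocm=False) is True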

vllm/model_executor/layers/quantization/quark/quark.py

Lines changed: 0 additions & 12 deletions
@@ -307,18 +307,6 @@ def get_cache_scale(self, name: str) -> Optional[str]:
         # If no matches, return None
         return None
 
-    def has_fp8_layer_weights(self):
-        layer_quant_config = self.quant_config.get("layer_quant_config")
-        to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
-        return any([
-            'fp8' in cast(
-                str,
-                to_dict(
-                    to_dict(to_dict(layer_quant_config).get(layer_name)).get(
-                        "weight")).get("dtype"))
-            for layer_name in ["*v_proj", "*k_proj", "*q_proj"]
-        ])
-
 
 class QuarkLinearMethod(LinearMethodBase):
 
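For reference, the deleted helper walked Quark's nested layer_quant_config mapping to see whether any attention projection quantizes its weights to an FP8 dtype. A standalone sketch of the same lookup against a plain dict; the sample config shape is assumed from the deleted code, not taken from Quark's documentation:

from typing import Any, Dict, Optional


def has_fp8_layer_weights(
        layer_quant_config: Optional[Dict[str, Any]]) -> bool:
    # Probe layer_quant_config[layer]["weight"]["dtype"] for an 'fp8'
    # substring, treating every missing level of the config as an empty
    # dict (mirrors the deleted to_dict() lambda).
    to_dict = lambda obj: obj or {}
    return any(
        "fp8" in str(
            to_dict(
                to_dict(to_dict(layer_quant_config).get(layer)).get(
                    "weight")).get("dtype"))
        for layer in ("*v_proj", "*k_proj", "*q_proj"))


# A config quantizing v_proj weights to an FP8 dtype is detected;
# an absent config is not.
assert has_fp8_layer_weights({"*v_proj": {"weight": {"dtype": "fp8_e4m3"}}})
assert not has_fp8_layer_weights(None)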