Commit 187e329

bjmsong and mdattack authored
[Bugfix] Change kv scaling factor by param json on nvidia gpu (#11688)
Signed-off-by: bjmsong <[email protected]>
Co-authored-by: bjmsong <[email protected]>
1 parent b55ed6e commit 187e329

File tree

5 files changed: +14 lines, -9 lines

vllm/model_executor/models/exaone.py

Lines changed: 3 additions & 2 deletions
@@ -606,8 +606,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")

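The comment kept in this hunk states the convention behind the loaded value: scaling_factor = tensor_amax / FPtype_max. Below is a minimal, illustrative sketch of how such a per-tensor FP8 E4M3 scale could be computed and applied, assuming PyTorch's torch.float8_e4m3fn dtype; this is not vLLM's own quantization path, only the convention the diff comment describes.

import torch

# Illustrative only: not vLLM's implementation of KV-cache quantization.
FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0


def compute_kv_scale(kv: torch.Tensor) -> float:
    # scaling_factor = tensor_amax / FPtype_max, as in the diff comment above.
    return kv.abs().max().item() / FP8_E4M3_MAX


def quantize_kv(kv: torch.Tensor, scale: float) -> torch.Tensor:
    # The convention is quantized_value * scaling_factor ~= true_value, so
    # quantization divides by the scale, mapping the tensor's absolute
    # maximum onto the FP8 representable range before the cast.
    return (kv / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)

The same replacement of the single _kv_scale attribute with separate _k_scale and _v_scale attributes is applied in granite.py, llama.py, and solar.py below.
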
vllm/model_executor/models/granite.py

Lines changed: 3 additions & 2 deletions
@@ -545,8 +545,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")

vllm/model_executor/models/llama.py

Lines changed: 3 additions & 2 deletions
@@ -452,8 +452,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")

vllm/model_executor/models/solar.py

Lines changed: 3 additions & 2 deletions
@@ -565,8 +565,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")

vllm/worker/model_runner.py

Lines changed: 2 additions & 1 deletion
@@ -1136,7 +1136,8 @@ def load_model(self) -> None:
                 self.prompt_adapter_manager.create_prompt_adapter_manager(
                     self.model))
 
-        if self.kv_cache_dtype == "fp8" and current_platform.is_rocm():
+        if self.kv_cache_dtype == "fp8" and (current_platform.is_rocm()
+                                             or current_platform.is_cuda()):
             # Currently only ROCm accepts kv-cache scaling factors
             # via quantization_param_path and this will be deprecated
             # in the future.
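With the platform check widened above, the JSON-supplied scaling factors should now be picked up on NVIDIA (CUDA) GPUs as well as ROCm. A hedged usage sketch follows, assuming the quantization_param_path engine argument is still accepted by the installed vLLM build (the comment in the hunk notes it is slated for deprecation); the model name and JSON path are placeholders.

from vllm import LLM

# Placeholder checkpoint and scales file; quantization_param_path is assumed
# to still exist in this vLLM version, as the hunk above implies.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    kv_cache_dtype="fp8",
    quantization_param_path="kv_cache_scales.json",
)
outputs = llm.generate("Hello, my name is")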
