
Commit 47512b3

Default to generation_config from model (#12622)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 3b9c6c6 · commit 47512b3

File tree

7 files changed: +27 -26 lines


tests/entrypoints/openai/correctness/test_lmeval.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 TASK = "gsm8k"
 FILTER = "exact_match,strict-match"
 RTOL = 0.03
-EXPECTED_VALUE = 0.58
+EXPECTED_VALUE = 0.54
 DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
 MORE_ARGS_LIST = [
     [],  # Default

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ class MockModelConfig:
     diff_sampling_param: Optional[dict] = None
     allowed_local_media_path: str = ""
     encoder_config = None
+    generation_config: str = "auto"

     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}

tests/test_config.py

Lines changed: 4 additions & 4 deletions
@@ -289,7 +289,7 @@ def test_uses_mrope(model_id, uses_mrope):
 def test_generation_config_loading():
     model_id = "Qwen/Qwen2.5-1.5B-Instruct"

-    # When set generation_config to None, the default generation config
+    # When set generation_config to "vllm", the default generation config
     # will not be loaded.
     model_config = ModelConfig(model_id,
                                task="auto",
@@ -298,7 +298,7 @@ def test_generation_config_loading():
                                trust_remote_code=False,
                                seed=0,
                                dtype="float16",
-                               generation_config=None)
+                               generation_config="vllm")
     assert model_config.get_diff_sampling_param() == {}

     # When set generation_config to "auto", the default generation config
@@ -340,7 +340,7 @@ def test_generation_config_loading():

     assert model_config.get_diff_sampling_param() == override_result

-    # When generation_config is set to None and override_generation_config
+    # When generation_config is set to "vllm" and override_generation_config
     # is set, the override_generation_config should be used directly.
     model_config = ModelConfig(
         model_id,
@@ -350,7 +350,7 @@ def test_generation_config_loading():
         trust_remote_code=False,
         seed=0,
         dtype="float16",
-        generation_config=None,
+        generation_config="vllm",
         override_generation_config=override_generation_config)

     assert model_config.get_diff_sampling_param() == override_generation_config

vllm/config.py

Lines changed: 6 additions & 9 deletions
@@ -255,7 +255,7 @@ def __init__(
         override_neuron_config: Optional[dict[str, Any]] = None,
         override_pooler_config: Optional["PoolerConfig"] = None,
         logits_processor_pattern: Optional[str] = None,
-        generation_config: Optional[str] = None,
+        generation_config: str = "auto",
         enable_sleep_mode: bool = False,
         override_generation_config: Optional[dict[str, Any]] = None,
         model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
@@ -951,7 +951,7 @@ def get_multimodal_config(self) -> "MultiModalConfig":
         return self.multimodal_config

     def try_get_generation_config(self) -> dict[str, Any]:
-        if self.generation_config is None or self.generation_config == "auto":
+        if self.generation_config in ("auto", "vllm"):
             config = try_get_generation_config(
                 self.hf_config_path or self.model,
                 trust_remote_code=self.trust_remote_code,
@@ -971,17 +971,14 @@ def try_get_generation_config(self) -> dict[str, Any]:
     def get_diff_sampling_param(self) -> dict[str, Any]:
         """
         This method returns a dictionary containing the parameters
-        that differ from the default sampling parameters, but only
-        if `generation_config` is set. If `generation_config` is not
-        set, an empty dictionary is returned.
+        that differ from the default sampling parameters. If
+        `generation_config` is `"vllm"`, an empty dictionary is returned.

         Returns:
             dict[str, Any]: A dictionary with the differing sampling
-                parameters if `generation_config` is set, otherwise an
-                empty dictionary.
+                parameters, if `generation_config` is `"vllm"` an empty dictionary.
         """
-        if self.generation_config is None:
-            # When generation_config is not set
+        if self.generation_config == "vllm":
             config = {}
         else:
             config = self.try_get_generation_config()
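
The dispatch in get_diff_sampling_param is the heart of the change: "vllm" now plays the role that None used to, while "auto" (the new default) pulls the sampling defaults shipped with the model. A minimal, self-contained sketch of that behaviour follows; resolve_default_sampling_params and its arguments are illustrative stand-ins, not part of the vLLM API.

from typing import Any, Optional


def resolve_default_sampling_params(
        generation_config: str,
        model_generation_config: Optional[dict[str, Any]]) -> dict[str, Any]:
    """Sketch of the new dispatch in ModelConfig.get_diff_sampling_param().

    - "vllm": ignore the model's generation_config, use vLLM's own defaults.
    - "auto" (new default) or a folder path: use the config supplied by the
      model (loading from disk is omitted in this sketch).
    """
    if generation_config == "vllm":
        return {}
    return dict(model_generation_config or {})


# A model whose generation_config.json pins temperature/top_p:
model_cfg = {"temperature": 0.6, "top_p": 0.9}
print(resolve_default_sampling_params("auto", model_cfg))  # {'temperature': 0.6, 'top_p': 0.9}
print(resolve_default_sampling_params("vllm", model_cfg))  # {}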

vllm/engine/arg_utils.py

Lines changed: 7 additions & 7 deletions
@@ -207,7 +207,7 @@ class EngineArgs:

     kv_transfer_config: Optional[KVTransferConfig] = None

-    generation_config: Optional[str] = None
+    generation_config: Optional[str] = "auto"
     override_generation_config: Optional[Dict[str, Any]] = None
     enable_sleep_mode: bool = False
     model_impl: str = "auto"
@@ -1034,13 +1034,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument(
             "--generation-config",
             type=nullable_str,
-            default=None,
+            default="auto",
             help="The folder path to the generation config. "
-            "Defaults to None, no generation config is loaded, vLLM defaults "
-            "will be used. If set to 'auto', the generation config will be "
-            "loaded from model path. If set to a folder path, the generation "
-            "config will be loaded from the specified folder path. If "
-            "`max_new_tokens` is specified in generation config, then "
+            "Defaults to 'auto', the generation config will be loaded from "
+            "model path. If set to 'vllm', no generation config is loaded, "
+            "vLLM defaults will be used. If set to a folder path, the "
+            "generation config will be loaded from the specified folder path. "
+            "If `max_new_tokens` is specified in generation config, then "
             "it sets a server-wide limit on the number of output tokens "
             "for all requests.")
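
For anyone relying on the old default, the previous behaviour is still reachable by passing --generation-config vllm on the CLI. A hedged Python equivalent via EngineArgs (the dataclass patched above) is sketched below; the model name is only an example.

from vllm.engine.arg_utils import EngineArgs

# New default ("auto"): sampling defaults come from the model's
# generation_config.json, if it ships one.
args_auto = EngineArgs(model="Qwen/Qwen2.5-1.5B-Instruct")

# Pre-change behaviour: skip the model's generation_config entirely and
# fall back to vLLM's built-in sampling defaults.
args_vllm = EngineArgs(model="Qwen/Qwen2.5-1.5B-Instruct",
                       generation_config="vllm")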

vllm/entrypoints/openai/serving_chat.py

Lines changed: 4 additions & 2 deletions
@@ -109,8 +109,10 @@ def __init__(
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
         if self.default_sampling_params:
-            logger.info("Overwriting default chat sampling param with: %s",
-                        self.default_sampling_params)
+            source = self.model_config.generation_config
+            source = "model" if source == "auto" else source
+            logger.info("Using default chat sampling params from %s: %s",
+                        source, self.default_sampling_params)

     async def create_chat_completion(
         self,

vllm/entrypoints/openai/serving_completion.py

Lines changed: 4 additions & 3 deletions
@@ -55,9 +55,10 @@ def __init__(
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
         if self.default_sampling_params:
-            logger.info(
-                "Overwriting default completion sampling param with: %s",
-                self.default_sampling_params)
+            source = self.model_config.generation_config
+            source = "model" if source == "auto" else source
+            logger.info("Using default completion sampling params from %s: %s",
+                        source, self.default_sampling_params)

     async def create_completion(
         self,
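
The chat and completion entrypoints now derive the same human-readable "source" label before logging. A small helper like the hypothetical one below would capture that shared logic; the name sampling_params_source is illustrative and not part of this commit.

def sampling_params_source(generation_config: str) -> str:
    """Map the configured generation_config value to a log-friendly source."""
    return "model" if generation_config == "auto" else generation_config


assert sampling_params_source("auto") == "model"
assert sampling_params_source("vllm") == "vllm"
assert sampling_params_source("/path/to/generation_config") == "/path/to/generation_config"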
