
Default to generation_config from model #12622


Merged · 15 commits · Mar 8, 2025 · Changes shown from 7 commits
vllm/config.py (15 changes: 6 additions & 9 deletions)
```diff
@@ -254,7 +254,7 @@
         override_neuron_config: Optional[Dict[str, Any]] = None,
         override_pooler_config: Optional["PoolerConfig"] = None,
         logits_processor_pattern: Optional[str] = None,
-        generation_config: Optional[str] = None,
+        generation_config: Optional[str] = "auto",
         enable_sleep_mode: bool = False,
         override_generation_config: Optional[Dict[str, Any]] = None,
         model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
```
```diff
@@ -950,7 +950,7 @@
         return self.multimodal_config

     def try_get_generation_config(self) -> Dict[str, Any]:
-        if self.generation_config is None or self.generation_config == "auto":
+        if self.generation_config in ("auto", "vllm"):
             config = try_get_generation_config(
                 self.hf_config_path or self.model,
                 trust_remote_code=self.trust_remote_code,
@@ -958,7 +958,7 @@
             )
         else:
             config = try_get_generation_config(
                 self.generation_config,
                 trust_remote_code=self.trust_remote_code,
             )
```

Check failure on line 961 (`self.generation_config,`) in vllm/config.py — GitHub Actions / pre-commit: Argument 1 to "try_get_generation_config" has incompatible type "Optional[str]"; expected "str" [arg-type]

```diff
@@ -970,17 +970,14 @@
     def get_diff_sampling_param(self) -> Dict[str, Any]:
         """
         This method returns a dictionary containing the parameters
-        that differ from the default sampling parameters, but only
-        if `generation_config` is set. If `generation_config` is not
-        set, an empty dictionary is returned.
+        that differ from the default sampling parameters. If
+        `generation_config` is `"vllm"`, an empty dictionary is returned.

         Returns:
             Dict[str, Any]: A dictionary with the differing sampling
-                parameters if `generation_config` is set, otherwise an
-                empty dictionary.
+                parameters, if `generation_config` is `"vllm"` an empty dictionary.
         """
-        if self.generation_config is None:
-            # When generation_config is not set
+        if self.generation_config == "vllm":
             config = {}
         else:
             config = self.try_get_generation_config()
```
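The net effect of the `get_diff_sampling_param` change can be shown with a minimal, self-contained sketch. The dict below is a stand-in for a model's `generation_config.json` (the real code loads it via `try_get_generation_config`), and the values are illustrative:

```python
from typing import Any, Dict, Optional

# Stand-in for a model's generation_config.json; values are illustrative.
MODEL_GENERATION_CONFIG = {"temperature": 0.6, "top_p": 0.9}

def get_diff_sampling_param(generation_config: Optional[str]) -> Dict[str, Any]:
    # "vllm" opts out of model-supplied defaults entirely.
    if generation_config == "vllm":
        return {}
    # "auto" (the new default) or a folder path: return the model-supplied
    # parameters that differ from vLLM's built-in sampling defaults.
    return dict(MODEL_GENERATION_CONFIG)

print(get_diff_sampling_param("vllm"))  # {}
print(get_diff_sampling_param("auto"))  # {'temperature': 0.6, 'top_p': 0.9}
```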
vllm/engine/arg_utils.py (12 changes: 6 additions & 6 deletions)
```diff
@@ -1016,13 +1016,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument(
             "--generation-config",
             type=nullable_str,
-            default=None,
+            default="auto",
             help="The folder path to the generation config. "
-            "Defaults to None, no generation config is loaded, vLLM defaults "
-            "will be used. If set to 'auto', the generation config will be "
-            "loaded from model path. If set to a folder path, the generation "
-            "config will be loaded from the specified folder path. If "
-            "`max_new_tokens` is specified in generation config, then "
+            "Defaults to 'auto', the generation config will be loaded from "
+            "model path. If set to 'vllm', no generation config is loaded, "
+            "vLLM defaults will be used. If set to a folder path, the "
+            "generation config will be loaded from the specified folder path. "
+            "If `max_new_tokens` is specified in generation config, then "
             "it sets a server-wide limit on the number of output tokens "
             "for all requests.")
```
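The same three values apply to the offline entrypoint, assuming `LLM(...)` forwards the keyword on to `EngineArgs` as the CLI does; the model name and folder path below are placeholders:

```python
from vllm import LLM

MODEL = "Qwen/Qwen2.5-1.5B-Instruct"  # placeholder model name

# New default "auto": sampling defaults come from the model's generation config.
llm = LLM(model=MODEL)

# Opt out and keep vLLM's built-in defaults:
# llm = LLM(model=MODEL, generation_config="vllm")

# Load a generation config from a custom folder:
# llm = LLM(model=MODEL, generation_config="/path/to/generation/config")
```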
vllm/entrypoints/llm.py (4 changes: 4 additions & 0 deletions)
```diff
@@ -271,6 +271,10 @@ def get_default_sampling_params(self) -> SamplingParams:
         diff_sampling_param = (
             self.llm_engine.model_config.get_diff_sampling_param())
         if diff_sampling_param:
+            source = self.llm_engine.model_config.generation_config
+            source = "model" if source == "auto" else source
+            logger.info("Using default sampling params from %s: %s", source,
+                        diff_sampling_param)
             return SamplingParams.from_optional(**diff_sampling_param)
         return SamplingParams()
```
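With the new default in place, `get_default_sampling_params` now reflects the model's generation config, and the added log line reports where the defaults came from. A short usage sketch (placeholder model name; exact values depend on the model):

```python
from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model name
params = llm.get_default_sampling_params()
# Per the diff above, the engine logs, e.g.:
#   Using default sampling params from model: {'temperature': 0.6, ...}
# and `params` carries the model's defaults rather than vLLM's built-ins.
```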
vllm/entrypoints/openai/serving_chat.py (6 changes: 4 additions & 2 deletions)
```diff
@@ -108,8 +108,10 @@ def __init__(
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
         diff_sampling_param = self.model_config.get_diff_sampling_param()
         if diff_sampling_param:
-            logger.info("Overwriting default chat sampling param with: %s",
-                        diff_sampling_param)
+            source = self.model_config.generation_config
+            source = "model" if source == "auto" else source
+            logger.info("Using default chat sampling params from %s: %s",
+                        source, diff_sampling_param)

     async def create_chat_completion(
         self,
```
vllm/entrypoints/openai/serving_completion.py (7 changes: 4 additions & 3 deletions)
```diff
@@ -53,9 +53,10 @@ def __init__(
             return_tokens_as_token_ids=return_tokens_as_token_ids)
         diff_sampling_param = self.model_config.get_diff_sampling_param()
         if diff_sampling_param:
-            logger.info(
-                "Overwriting default completion sampling param with: %s",
-                diff_sampling_param)
+            source = self.model_config.generation_config
+            source = "model" if source == "auto" else source
+            logger.info("Using default completion sampling params from %s: %s",
+                        source, diff_sampling_param)

     async def create_completion(
         self,
```
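Server-side, the effect is the same for the chat and completion endpoints: parameters the client omits now fall back to the model's generation config rather than vLLM's built-in defaults. A hedged sketch against a locally served model (placeholder model name):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder; whatever `vllm serve` loaded
    messages=[{"role": "user", "content": "Hello!"}],
    # temperature/top_p omitted: the server-side defaults reported by the
    # "Using default chat sampling params from model: ..." log apply.
)
print(resp.choices[0].message.content)
```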