Commit c4dfb6e

wayzeng authored and shreyankg committed
[Feature] specify model in config.yaml (vllm-project#14855)
Signed-off-by: weizeng <[email protected]>
1 parent f1dd743 commit c4dfb6e

File tree

7 files changed: +102 −30 lines changed


docs/source/serving/openai_compatible_server.md

Lines changed: 3 additions & 1 deletion
````diff
@@ -184,6 +184,7 @@ For example:
 ```yaml
 # config.yaml
 
+model: meta-llama/Llama-3.1-8B-Instruct
 host: "127.0.0.1"
 port: 6379
 uvicorn-log-level: "info"
@@ -192,12 +193,13 @@ uvicorn-log-level: "info"
 To use the above config file:
 
 ```bash
-vllm serve SOME_MODEL --config config.yaml
+vllm serve --config config.yaml
 ```
 
 :::{note}
 In case an argument is supplied simultaneously using command line and the config file, the value from the command line will take precedence.
 The order of priorities is `command line > config file values > defaults`.
+e.g. `vllm serve SOME_MODEL --config config.yaml`, SOME_MODEL takes precedence over `model` in config file.
 :::
 
 ## API Reference
````
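To see the documented precedence rule in action, here is a minimal sketch mirroring the parser tests changed in this commit (the temp-file config contents and model names are illustrative placeholders, not real models):

```python
# Sketch of the cli > config > defaults precedence using
# FlexibleArgumentParser; mirrors the tests in this commit.
import os
import tempfile

from vllm.utils import FlexibleArgumentParser

# Placeholder config file for illustration.
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write("model: config-model\nport: 12312\n")
    config_path = f.name

parser = FlexibleArgumentParser()
parser.add_argument("serve")
parser.add_argument("model_tag", nargs="?")
parser.add_argument("--model", type=str)
parser.add_argument("--config", type=str)
parser.add_argument("--port", type=int)

# Positional model on the CLI wins over `model:` in the config file.
args = parser.parse_args(["serve", "cli-model", "--config", config_path])
assert args.model_tag == "cli-model"

# With no positional model, the config file's `model` is picked up.
args = parser.parse_args(["serve", "--config", config_path])
assert args.model == "config-model"

# An explicit CLI flag still overrides the config file value.
args = parser.parse_args(["serve", "--config", config_path, "--port", "666"])
assert args.port == 666

os.unlink(config_path)
```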
tests/config/test_config.yaml

File renamed without changes (previously tests/data/test_config.yaml).
tests/config/test_config_with_model.yaml

Lines changed: 7 additions & 0 deletions

```diff
@@ -0,0 +1,7 @@
+# Same as test_config.yaml but with model specified
+model: config-model
+port: 12312
+served_model_name: mymodel
+tensor_parallel_size: 2
+trust_remote_code: true
+multi_step_stream_outputs: false
```

tests/conftest.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -1121,3 +1121,15 @@ def pytest_collection_modifyitems(config, items):
     for item in items:
         if "optional" in item.keywords:
             item.add_marker(skip_optional)
+
+
+@pytest.fixture(scope="session")
+def cli_config_file():
+    """Return the path to the CLI config file."""
+    return os.path.join(_TEST_DIR, "config", "test_config.yaml")
+
+
+@pytest.fixture(scope="session")
+def cli_config_file_with_model():
+    """Return the path to the CLI config file with model."""
+    return os.path.join(_TEST_DIR, "config", "test_config_with_model.yaml")
```
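As a usage note, pytest injects these session-scoped fixtures by parameter name, so tests stop hard-coding `./data/...` paths. A hypothetical consumer (for illustration only, not part of this commit):

```python
# Hypothetical test: the fixture's return value is the resolved path.
import os


def test_fixture_paths(cli_config_file, cli_config_file_with_model):
    assert os.path.isfile(cli_config_file)
    assert cli_config_file_with_model.endswith("test_config_with_model.yaml")
```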

tests/test_utils.py

Lines changed: 42 additions & 11 deletions
```diff
@@ -8,7 +8,7 @@
 
 import pytest
 import torch
-from vllm_test_utils import monitor
+from vllm_test_utils.monitor import monitor
 
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
@@ -140,7 +140,8 @@ def parser():
 def parser_with_config():
     parser = FlexibleArgumentParser()
     parser.add_argument('serve')
-    parser.add_argument('model_tag')
+    parser.add_argument('model_tag', nargs='?')
+    parser.add_argument('--model', type=str)
     parser.add_argument('--served-model-name', type=str)
     parser.add_argument('--config', type=str)
     parser.add_argument('--port', type=int)
@@ -196,29 +197,29 @@ def test_missing_required_argument(parser):
         parser.parse_args([])
 
 
-def test_cli_override_to_config(parser_with_config):
+def test_cli_override_to_config(parser_with_config, cli_config_file):
     args = parser_with_config.parse_args([
-        'serve', 'mymodel', '--config', './data/test_config.yaml',
+        'serve', 'mymodel', '--config', cli_config_file,
         '--tensor-parallel-size', '3'
     ])
     assert args.tensor_parallel_size == 3
     args = parser_with_config.parse_args([
         'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
-        './data/test_config.yaml'
+        cli_config_file
     ])
     assert args.tensor_parallel_size == 3
     assert args.port == 12312
     args = parser_with_config.parse_args([
         'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
-        './data/test_config.yaml', '--port', '666'
+        cli_config_file, '--port', '666'
     ])
     assert args.tensor_parallel_size == 3
     assert args.port == 666
 
 
-def test_config_args(parser_with_config):
+def test_config_args(parser_with_config, cli_config_file):
     args = parser_with_config.parse_args(
-        ['serve', 'mymodel', '--config', './data/test_config.yaml'])
+        ['serve', 'mymodel', '--config', cli_config_file])
     assert args.tensor_parallel_size == 2
     assert args.trust_remote_code
     assert not args.multi_step_stream_outputs
@@ -240,10 +241,9 @@ def test_config_file(parser_with_config):
     ])
 
 
-def test_no_model_tag(parser_with_config):
+def test_no_model_tag(parser_with_config, cli_config_file):
     with pytest.raises(ValueError):
-        parser_with_config.parse_args(
-            ['serve', '--config', './data/test_config.yaml'])
+        parser_with_config.parse_args(['serve', '--config', cli_config_file])
 
 
 # yapf: enable
@@ -476,3 +476,34 @@ def test_swap_dict_values(obj, key1, key2):
         assert obj[key1] == original_obj[key2]
     else:
         assert key1 not in obj
+
+
+def test_model_specification(parser_with_config,
+                             cli_config_file,
+                             cli_config_file_with_model):
+    # Test model in CLI takes precedence over config
+    args = parser_with_config.parse_args([
+        'serve', 'cli-model', '--config', cli_config_file_with_model
+    ])
+    assert args.model_tag == 'cli-model'
+    assert args.served_model_name == 'mymodel'
+
+    # Test model from config file works
+    args = parser_with_config.parse_args([
+        'serve', '--config', cli_config_file_with_model
+    ])
+    assert args.model == 'config-model'
+    assert args.served_model_name == 'mymodel'
+
+    # Test no model specified anywhere raises error
+    with pytest.raises(ValueError, match="No model specified!"):
+        parser_with_config.parse_args(['serve', '--config', cli_config_file])
+
+    # Test other config values are preserved
+    args = parser_with_config.parse_args([
+        'serve', 'cli-model', '--config', cli_config_file_with_model
+    ])
+    assert args.tensor_parallel_size == 2
+    assert args.trust_remote_code is True
+    assert args.multi_step_stream_outputs is False
+    assert args.port == 12312
```
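The switch to `nargs='?'` is what lets `vllm serve --config ...` parse at all; a quick standalone illustration with plain argparse:

```python
# Why model_tag needs nargs='?': with a required positional,
# `serve --config cfg.yaml` would fail to parse before the config
# file could ever supply the model.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("model_tag", nargs="?")  # optional positional

print(parser.parse_args([]))            # Namespace(model_tag=None)
print(parser.parse_args(["my-model"]))  # Namespace(model_tag='my-model')
```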

vllm/entrypoints/cli/serve.py

Lines changed: 13 additions & 9 deletions
```diff
@@ -21,14 +21,16 @@ def __init__(self):
 
     @staticmethod
     def cmd(args: argparse.Namespace) -> None:
-        # The default value of `--model`
-        if args.model != EngineArgs.model:
-            raise ValueError(
-                "With `vllm serve`, you should provide the model as a "
-                "positional argument instead of via the `--model` option.")
+        # If model is specified in CLI (as positional arg), it takes precedence
+        if hasattr(args, 'model_tag') and args.model_tag is not None:
+            args.model = args.model_tag
+        # Otherwise use model from config (already in args.model)
 
-        # EngineArgs expects the model name to be passed as --model.
-        args.model = args.model_tag
+        # Check if we have a model specified somewhere
+        if args.model == EngineArgs.model:  # Still has default value
+            raise ValueError(
+                "With `vllm serve`, you should provide the model either as a "
+                "positional argument or in config file.")
 
         uvloop.run(run_server(args))
 
@@ -41,10 +43,12 @@ def subparser_init(
         serve_parser = subparsers.add_parser(
             "serve",
             help="Start the vLLM OpenAI Compatible API server",
-            usage="vllm serve <model_tag> [options]")
+            usage="vllm serve [model_tag] [options]")
         serve_parser.add_argument("model_tag",
                                   type=str,
-                                  help="The model tag to serve")
+                                  nargs='?',
+                                  help="The model tag to serve "
+                                  "(optional if specified in config)")
         serve_parser.add_argument(
             "--config",
             type=str,
```
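The resolution order in `cmd()` boils down to: positional tag first, then whatever the config wrote into `args.model`, else error. A standalone sketch of that logic (the names and the default value are illustrative, not the vLLM API):

```python
# Illustrative re-statement of the cmd() logic above; DEFAULT_MODEL
# stands in for EngineArgs.model (an assumption, not the real constant).
from typing import Optional

DEFAULT_MODEL = "default-model"  # placeholder for EngineArgs.model


def resolve_model(model_tag: Optional[str], model: str) -> str:
    if model_tag is not None:    # positional CLI argument wins
        return model_tag
    if model == DEFAULT_MODEL:   # untouched default: nothing specified
        raise ValueError("No model specified via CLI or config file.")
    return model                 # value injected from config.yaml


assert resolve_model("cli-model", "config-model") == "cli-model"
assert resolve_model(None, "config-model") == "config-model"
```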

vllm/utils.py

Lines changed: 25 additions & 9 deletions
```diff
@@ -1264,19 +1264,29 @@ def _pull_args_from_config(self, args: list[str]) -> list[str]:
         config_args = self._load_config_file(file_path)
 
         # 0th index is for {serve,chat,complete}
-        # followed by model_tag (only for serve)
+        # optionally followed by model_tag (only for serve)
         # followed by config args
         # followed by rest of cli args.
         # maintaining this order will enforce the precedence
         # of cli > config > defaults
         if args[0] == "serve":
-            if index == 1:
+            model_in_cli = len(args) > 1 and not args[1].startswith('-')
+            model_in_config = any(arg == '--model' for arg in config_args)
+
+            if not model_in_cli and not model_in_config:
                 raise ValueError(
-                    "No model_tag specified! Please check your command-line"
-                    " arguments.")
-            args = [args[0]] + [
-                args[1]
-            ] + config_args + args[2:index] + args[index + 2:]
+                    "No model specified! Please specify model either in "
+                    "command-line arguments or in config file.")
+
+            if model_in_cli:
+                # Model specified as positional arg, keep CLI version
+                args = [args[0]] + [
+                    args[1]
+                ] + config_args + args[2:index] + args[index + 2:]
+            else:
+                # No model in CLI, use config if available
+                args = [args[0]
+                        ] + config_args + args[1:index] + args[index + 2:]
         else:
             args = [args[0]] + config_args + args[1:index] + args[index + 2:]
 
@@ -1294,9 +1304,7 @@ def _load_config_file(self, file_path: str) -> list[str]:
             '--port': '12323',
             '--tensor-parallel-size': '4'
         ]
-
         """
-
         extension: str = file_path.split('.')[-1]
         if extension not in ('yaml', 'yml'):
             raise ValueError(
@@ -1321,7 +1329,15 @@ def _load_config_file(self, file_path: str) -> list[str]:
             if isinstance(action, StoreBoolean)
         ]
 
+        # Skip model from config if it's provided as positional argument
+        skip_model = (hasattr(self, '_parsed_args') and self._parsed_args
+                      and len(self._parsed_args) > 1
+                      and self._parsed_args[0] == 'serve'
+                      and not self._parsed_args[1].startswith('-'))
+
         for key, value in config.items():
+            if skip_model and key == 'model':
+                continue
             if isinstance(value, bool) and key not in store_boolean_arguments:
                 if value:
                     processed_args.append('--' + key)
```
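A worked example of the splice `_pull_args_from_config` performs (plain Python with a made-up arg list, not the vLLM code itself): config args are inserted before the remaining CLI args, so argparse's last-value-wins behavior yields cli > config > defaults.

```python
# `index` is the position of '--config' in the raw argv list.
args = ["serve", "--config", "config.yaml", "--port", "666"]
config_args = ["--model", "config-model", "--port", "12312"]

index = args.index("--config")  # 1 here
model_in_cli = len(args) > 1 and not args[1].startswith("-")  # False here

if model_in_cli:
    # Keep the positional model tag in front of the config args.
    merged = ([args[0], args[1]] + config_args + args[2:index]
              + args[index + 2:])
else:
    merged = [args[0]] + config_args + args[1:index] + args[index + 2:]

# Config values come first, remaining CLI args last, so the CLI's
# '--port 666' overrides the file's '--port 12312' in argparse.
print(merged)
# ['serve', '--model', 'config-model', '--port', '12312', '--port', '666']
```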
