Commit 519e8e4

[v1] EngineArgs for better config handling for v1 (#10382)
Signed-off-by: rickyx <[email protected]>
1 parent a6760f6 commit 519e8e4

13 files changed: +109 -27

Diff for: .buildkite/test-pipeline.yaml

+1 -1

@@ -172,7 +172,7 @@ steps:
   - vllm/
   - tests/v1
   commands:
-  - pytest -v -s v1
+  - VLLM_USE_V1=1 pytest -v -s v1
 
 - label: Examples Test # 15min
   working_dir: "/vllm-workspace/examples"

Diff for: tests/v1/engine/test_async_llm.py

+3

@@ -32,6 +32,9 @@ async def generate(engine: AsyncLLM, request_id: str,
 
 @pytest.mark.asyncio
 async def test_load(monkeypatch):
+    # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
+    # so that in the future when we switch, we don't have to change all the
+    # tests.
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

Diff for: tests/v1/engine/test_engine_args.py

+42

@@ -0,0 +1,42 @@
+import pytest
+
+from vllm import envs
+from vllm.config import VllmConfig
+from vllm.engine.arg_utils import EngineArgs
+from vllm.usage.usage_lib import UsageContext
+
+if not envs.VLLM_USE_V1:
+    pytest.skip(
+        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
+        allow_module_level=True,
+    )
+
+
+def test_defaults():
+    engine_args = EngineArgs(model="facebook/opt-125m")
+
+    # Assert V1 defaults
+    assert (engine_args.enable_prefix_caching
+            ), "V1 turns on prefix caching by default"
+
+
+def test_defaults_with_usage_context():
+    engine_args = EngineArgs(model="facebook/opt-125m")
+    vllm_config: VllmConfig = engine_args.create_engine_config(
+        UsageContext.LLM_CLASS)
+
+    assert vllm_config.scheduler_config.max_num_seqs == 1024
+    assert vllm_config.scheduler_config.max_num_batched_tokens == 8192
+
+    engine_args = EngineArgs(model="facebook/opt-125m")
+    vllm_config = engine_args.create_engine_config(
+        UsageContext.OPENAI_API_SERVER)
+    assert vllm_config.scheduler_config.max_num_seqs == 1024
+    assert vllm_config.scheduler_config.max_num_batched_tokens == 2048
+
+
+def test_prefix_cache_disabled_with_multimodel():
+    engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")
+
+    vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
+    assert not vllm_config.cache_config.enable_prefix_caching
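
These tests skip themselves unless V1 is enabled; per the skip message above and the Buildkite change, they can be run locally with something like `VLLM_USE_V1=1 pytest -v -s tests/v1/engine/test_engine_args.py`.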

Diff for: tests/v1/engine/test_engine_core.py

+2 -1

@@ -43,7 +43,8 @@ def test_engine_core(monkeypatch):
         m.setenv("VLLM_USE_V1", "1")
         """Setup the EngineCore."""
         engine_args = EngineArgs(model=MODEL_NAME)
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(
+            usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = AsyncLLM._get_executor_cls(vllm_config)
 
         engine_core = EngineCore(vllm_config=vllm_config,

Diff for: tests/v1/engine/test_engine_core_client.py

+4 -2

@@ -82,7 +82,8 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
         m.setenv("VLLM_USE_V1", "1")
 
         engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(
+            UsageContext.UNKNOWN_CONTEXT)
         executor_class = AsyncLLM._get_executor_cls(vllm_config)
         client = EngineCoreClient.make_client(
             vllm_config,
@@ -153,7 +154,8 @@ async def test_engine_core_client_asyncio(monkeypatch):
         m.setenv("VLLM_USE_V1", "1")
 
         engine_args = EngineArgs(model=MODEL_NAME)
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(
+            usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = AsyncLLM._get_executor_cls(vllm_config)
         client = EngineCoreClient.make_client(
             vllm_config,

Diff for: vllm/engine/arg_utils.py

+50 -3

@@ -20,6 +20,7 @@
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.platforms import current_platform
 from vllm.transformers_utils.utils import check_gguf_file
+from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, StoreBoolean
 
 if TYPE_CHECKING:
@@ -113,7 +114,7 @@ class EngineArgs:
     # NOTE(kzawora): default block size for Gaudi should be 128
     # smaller sizes still work, but very inefficiently
     block_size: int = 16 if not current_platform.is_hpu() else 128
-    enable_prefix_caching: bool = False
+    enable_prefix_caching: Optional[bool] = None
     disable_sliding_window: bool = False
     use_v2_block_manager: bool = True
     swap_space: float = 4  # GiB
@@ -197,6 +198,11 @@ def __post_init__(self):
         if not self.tokenizer:
             self.tokenizer = self.model
 
+        # Override the default value of enable_prefix_caching if it's not set
+        # by user.
+        if self.enable_prefix_caching is None:
+            self.enable_prefix_caching = bool(envs.VLLM_USE_V1)
+
         # support `EngineArgs(compilation_config={...})`
         # without having to manually construct a
         # CompilationConfig object
@@ -953,7 +959,12 @@ def create_load_config(self) -> LoadConfig:
             ignore_patterns=self.ignore_patterns,
         )
 
-    def create_engine_config(self) -> VllmConfig:
+    def create_engine_config(self,
+                             usage_context: Optional[UsageContext] = None
+                             ) -> VllmConfig:
+        if envs.VLLM_USE_V1:
+            self._override_v1_engine_args(usage_context)
+
         # gguf file needs a specific model loader and doesn't use hf_repo
         if check_gguf_file(self.model):
             self.quantization = self.load_format = "gguf"
@@ -1170,7 +1181,7 @@ def create_engine_config(self) -> VllmConfig:
             or "all" in detailed_trace_modules,
         )
 
-        return VllmConfig(
+        config = VllmConfig(
             model_config=model_config,
             cache_config=cache_config,
             parallel_config=parallel_config,
@@ -1185,6 +1196,42 @@ def create_engine_config(self) -> VllmConfig:
             compilation_config=self.compilation_config,
         )
 
+        if envs.VLLM_USE_V1:
+            self._override_v1_engine_config(config)
+        return config
+
+    def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
+        """
+        Override the EngineArgs's args based on the usage context for V1.
+        """
+        assert envs.VLLM_USE_V1, "V1 is not enabled"
+
+        if self.max_num_batched_tokens is None:
+            # When no user override, set the default values based on the
+            # usage context.
+            if usage_context == UsageContext.LLM_CLASS:
+                logger.warning("Setting max_num_batched_tokens to 8192 "
+                               "for LLM_CLASS usage context.")
+                self.max_num_seqs = 1024
+                self.max_num_batched_tokens = 8192
+            elif usage_context == UsageContext.OPENAI_API_SERVER:
+                logger.warning("Setting max_num_batched_tokens to 2048 "
+                               "for OPENAI_API_SERVER usage context.")
+                self.max_num_seqs = 1024
+                self.max_num_batched_tokens = 2048
+
+    def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
+        """
+        Override the EngineConfig's configs based on the usage context for V1.
+        """
+        assert envs.VLLM_USE_V1, "V1 is not enabled"
+        # TODO (ywang96): Enable APC by default when VLM supports it.
+        if engine_config.model_config.is_multimodal_model:
+            logger.warning(
+                "Prefix caching is currently not supported for multimodal "
+                "models and has been disabled.")
+            engine_config.cache_config.enable_prefix_caching = False
+
 
 @dataclass
 class AsyncEngineArgs(EngineArgs):
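
For context, a minimal sketch of how the new `usage_context` parameter plays out, mirroring the new tests in tests/v1/engine/test_engine_args.py and assuming `VLLM_USE_V1=1` is set in the environment so the V1 overrides apply:

    from vllm.engine.arg_utils import EngineArgs
    from vllm.usage.usage_lib import UsageContext

    # The OpenAI API server context gets the smaller 2048-token batch budget.
    config = EngineArgs(model="facebook/opt-125m").create_engine_config(
        usage_context=UsageContext.OPENAI_API_SERVER)
    assert config.scheduler_config.max_num_seqs == 1024
    assert config.scheduler_config.max_num_batched_tokens == 2048

    # The LLM class context gets the larger 8192-token budget; a fresh
    # EngineArgs is built because the override mutates the args in place.
    config = EngineArgs(model="facebook/opt-125m").create_engine_config(
        usage_context=UsageContext.LLM_CLASS)
    assert config.scheduler_config.max_num_batched_tokens == 8192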

Diff for: vllm/engine/async_llm_engine.py

+1 -1

@@ -680,7 +680,7 @@ def from_engine_args(
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         if engine_config is None:
-            engine_config = engine_args.create_engine_config()
+            engine_config = engine_args.create_engine_config(usage_context)
 
         executor_class = cls._get_executor_cls(engine_config)

Diff for: vllm/engine/llm_engine.py

+1 -1

@@ -568,7 +568,7 @@ def from_engine_args(
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        engine_config = engine_args.create_engine_config(usage_context)
         executor_class = cls._get_executor_cls(engine_config)
         # Create the LLM engine.
         engine = cls(

Diff for: vllm/engine/multiprocessing/engine.py

+1 -1

@@ -111,7 +111,7 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs,
         from vllm.plugins import load_general_plugins
         load_general_plugins()
 
-        engine_config = engine_args.create_engine_config()
+        engine_config = engine_args.create_engine_config(usage_context)
         executor_class = LLMEngine._get_executor_cls(engine_config)
 
         use_async_sockets = engine_config.model_config.use_async_output_proc

Diff for: vllm/entrypoints/openai/api_server.py

+2 -2

@@ -135,8 +135,8 @@ async def build_async_engine_client_from_engine_args(
     # TODO: fill out feature matrix.
     if (MQLLMEngineClient.is_unsupported_config(engine_args)
             or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
-
-        engine_config = engine_args.create_engine_config()
+        engine_config = engine_args.create_engine_config(
+            UsageContext.OPENAI_API_SERVER)
         uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
                            "uses_ray", False)

Diff for: vllm/v1/engine/async_llm.py

+1 -1

@@ -94,7 +94,7 @@ def from_engine_args(
 
         # Create the engine configs.
         if engine_config is None:
-            vllm_config = engine_args.create_engine_config()
+            vllm_config = engine_args.create_engine_config(usage_context)
         else:
             vllm_config = engine_config
 
Diff for: vllm/v1/engine/core.py

-13

@@ -41,19 +41,6 @@ def __init__(
         executor_class: Type[GPUExecutor],
         usage_context: UsageContext,
     ):
-        # Override the configs for V1.
-        # FIXME
-        if usage_context == UsageContext.LLM_CLASS:
-            vllm_config.scheduler_config.max_num_seqs = 1024
-            vllm_config.scheduler_config.max_num_batched_tokens = 8192
-        elif usage_context == UsageContext.OPENAI_API_SERVER:
-            vllm_config.scheduler_config.max_num_seqs = 1024
-            vllm_config.scheduler_config.max_num_batched_tokens = 2048
-
-        # TODO (ywang96): Enable APC by default when VLM supports it.
-        if not vllm_config.model_config.is_multimodal_model:
-            vllm_config.cache_config.enable_prefix_caching = True
-
         assert vllm_config.model_config.task != "embedding"
 
         logger.info("Initializing an LLM engine (v%s) with config: %s",
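
(The hard-coded overrides removed here now live in `EngineArgs._override_v1_engine_args` and `EngineArgs._override_v1_engine_config`, applied when the config is created; see the vllm/engine/arg_utils.py diff above.)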

Diff for: vllm/v1/engine/llm_engine.py

+1 -1

@@ -82,7 +82,7 @@ def from_engine_args(
         """Creates an LLM engine from the engine arguments."""
 
         # Create the engine configs.
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(usage_context)
         executor_class = cls._get_executor_cls(vllm_config)
 
         if VLLM_ENABLE_V1_MULTIPROCESSING:
