
Commit d7a090a

njhill authored and mzusman committed
[Frontend][V1] Online serving performance improvements (vllm-project#12287)
1 parent 4c05cba commit d7a090a


7 files changed: +101 -45 lines changed


vllm/entrypoints/openai/api_server.py

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,6 @@
 import asyncio
 import atexit
+import gc
 import importlib
 import inspect
 import multiprocessing
@@ -104,6 +105,11 @@ async def _force_log():
             task.add_done_callback(_running_tasks.remove)
         else:
             task = None
+
+        # Mark the startup heap as static so that it's ignored by GC.
+        # Reduces pause times of oldest generation collections.
+        gc.collect()
+        gc.freeze()
         try:
             yield
         finally:
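
Note on the change above: gc.freeze() (Python 3.7+) moves every object that survives the preceding gc.collect() into a permanent generation that later collections skip, so objects allocated during server startup stop contributing to old-generation pause times. A minimal standalone sketch of the pattern, with an illustrative helper name that is not part of vLLM:

import gc


def freeze_startup_heap() -> None:
    # Collect whatever is garbage first, then mark all surviving objects
    # as permanent so future collections ignore them.
    gc.collect()
    gc.freeze()


# Hypothetical usage: call once after long-lived state has been built,
# just before the server starts handling requests.
# app_state = build_application()   # placeholder for startup work
# freeze_startup_heap()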

vllm/entrypoints/openai/protocol.py

Lines changed: 19 additions & 11 deletions
@@ -3,7 +3,7 @@
 import re
 import time
 from argparse import Namespace
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
 
 import torch
 from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -42,23 +42,31 @@ class OpenAIBaseModel(BaseModel):
     # OpenAI API does allow extra fields
     model_config = ConfigDict(extra="allow")
 
+    # Cache class field names
+    field_names: ClassVar[Optional[Set[str]]] = None
+
     @model_validator(mode="before")
     @classmethod
     def __log_extra_fields__(cls, data):
-        if isinstance(data, dict):
+
+        field_names = cls.field_names
+        if field_names is None:
+            if not isinstance(data, dict):
+                return data
             # Get all class field names and their potential aliases
             field_names = set()
             for field_name, field in cls.model_fields.items():
                 field_names.add(field_name)
-                if hasattr(field, 'alias') and field.alias:
-                    field_names.add(field.alias)
-
-            # Compare against both field names and aliases
-            extra_fields = data.keys() - field_names
-            if extra_fields:
-                logger.warning(
-                    "The following fields were present in the request "
-                    "but ignored: %s", extra_fields)
+                if alias := getattr(field, 'alias', None):
+                    field_names.add(alias)
+            cls.field_names = field_names
+
+        # Compare against both field names and aliases
+        if any(k not in field_names for k in data):
+            logger.warning(
+                "The following fields were present in the request "
+                "but ignored: %s",
+                data.keys() - field_names)
         return data
 
vllm/envs.py

Lines changed: 11 additions & 0 deletions
@@ -73,6 +73,7 @@
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
     VLLM_SERVER_DEV_MODE: bool = False
+    VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
 
 
 def get_default_cache_root():
@@ -474,6 +475,16 @@ def get_default_config_root():
     # e.g. `/reset_prefix_cache`
     "VLLM_SERVER_DEV_MODE":
     lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))),
+
+    # Controls the maximum number of requests to handle in a
+    # single asyncio task when processing per-token outputs in the
+    # V1 AsyncLLM interface. It is applicable when handling a high
+    # concurrency of streaming requests.
+    # Setting this too high can result in a higher variance of
+    # inter-message latencies. Setting it too low can negatively impact
+    # TTFT and overall throughput.
+    "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
+    lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),
 }
 
 # end-env-vars-definition
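
Usage note: the value is resolved via os.getenv when vllm.envs evaluates it, so it can be tuned from the environment without code changes. A hedged example (the value 64 is arbitrary, not a recommendation):

import os

# Must be set before vLLM reads the variable (i.e. before vllm is imported
# in-process, or in the environment of the server process).
os.environ["VLLM_V1_OUTPUT_PROC_CHUNK_SIZE"] = "64"

# Equivalent shell form when launching the OpenAI-compatible server:
#   VLLM_V1_OUTPUT_PROC_CHUNK_SIZE=64 vllm serve <model>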

vllm/v1/engine/async_llm.py

Lines changed: 37 additions & 17 deletions
@@ -2,9 +2,12 @@
 import os
 from typing import AsyncGenerator, List, Mapping, Optional, Type, Union
 
+import numpy as np
+
 from vllm.config import ModelConfig, VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.protocol import EngineClient
+from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
 from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
@@ -16,7 +19,7 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import kill_process_tree
+from vllm.utils import cdiv, kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.processor import Processor
@@ -205,17 +208,15 @@ async def generate(
 
             # The output_handler task pushes items into the queue.
             # This task pulls from the queue and yields to caller.
-            while True:
+            finished = False
+            while not finished:
                 # Note: drain queue without await if possible (avoids
                 # task switching under load which helps performance).
-                out = q.get_nowait() if q.qsize() > 0 else await q.get()
+                out = q.get_nowait() if not q.empty() else await q.get()
 
                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
-                if out.finished:
-                    yield out
-                    break
-
+                finished = out.finished
                 yield out
 
             # If the request is disconnected by the client, the
@@ -233,22 +234,41 @@ async def _run_output_handler(self):
                 # 1) Pull EngineCoreOutputs from the EngineCore.
                 outputs = await self.engine_core.get_output_async()
 
-                # 2) Process EngineCoreOutputs.
-                processed_outputs = self.output_processor.process_outputs(
-                    outputs.outputs)
-                # NOTE: RequestOutputs are pushed to their queues.
-                assert len(processed_outputs.request_outputs) == 0
-
-                # 3) Abort any reqs that finished due to stop strings.
-                await self.engine_core.abort_requests_async(
-                    processed_outputs.reqs_to_abort)
+                # Split outputs into chunks of at most
+                # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
+                # event loop for too long.
+                num_outputs = len(outputs.outputs)
+                if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
+                    slices = (outputs.outputs, )
+                else:
+                    slices = np.array_split(
+                        outputs.outputs,
+                        cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
+
+                iteration_stats = None
+                for i, outputs_slice in enumerate(slices):
+                    # 2) Process EngineCoreOutputs.
+                    processed_outputs = self.output_processor.process_outputs(
+                        outputs_slice, iteration_stats)
+                    # NOTE: RequestOutputs are pushed to their queues.
+                    assert not processed_outputs.request_outputs
+                    iteration_stats = processed_outputs.iteration_stats
+
+                    # Allow other asyncio tasks to run between chunks
+                    if i + 1 < len(slices):
+                        await asyncio.sleep(0)
+
+                    # 3) Abort any reqs that finished due to stop strings.
+                    await self.engine_core.abort_requests_async(
+                        processed_outputs.reqs_to_abort)
 
                 # 4) Logging.
                 # TODO(rob): make into a coroutine and launch it in
                 # background thread once we add Prometheus.
+                assert iteration_stats is not None
                 self._log_stats(
                     scheduler_stats=outputs.scheduler_stats,
-                    iteration_stats=processed_outputs.iteration_stats,
+                    iteration_stats=iteration_stats,
                 )
 
             except Exception as e:
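
Two patterns in this file are worth spelling out: the per-request generator drains its queue with get_nowait() whenever items are already buffered (avoiding a task switch under load), and the output handler splits large batches into roughly equal chunks, yielding to the event loop between them. A standalone sketch under simplified assumptions (process is a hypothetical callable and CHUNK_SIZE stands in for VLLM_V1_OUTPUT_PROC_CHUNK_SIZE; this is not vLLM's implementation):

import asyncio
from typing import Any, Callable, Sequence

import numpy as np

CHUNK_SIZE = 128  # stand-in for VLLM_V1_OUTPUT_PROC_CHUNK_SIZE


def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring vllm.utils.cdiv.
    return -(-a // b)


async def drain(q: "asyncio.Queue[Any]") -> Any:
    # Prefer the non-blocking path when an item is already buffered; only
    # await (and allow a task switch) when the queue is empty.
    return q.get_nowait() if not q.empty() else await q.get()


async def handle_batch(outputs: Sequence[Any],
                       process: Callable[[Sequence[Any]], None]) -> None:
    # Split into at most CHUNK_SIZE-sized slices so one huge batch does not
    # monopolize the event loop.
    if len(outputs) <= CHUNK_SIZE:
        slices = [outputs]
    else:
        slices = np.array_split(list(outputs),
                                cdiv(len(outputs), CHUNK_SIZE))
    for i, chunk in enumerate(slices):
        process(chunk)
        if i + 1 < len(slices):
            # Let other tasks (e.g. per-request streamers) run between chunks.
            await asyncio.sleep(0)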

vllm/v1/engine/core_client.py

Lines changed: 18 additions & 3 deletions
@@ -1,8 +1,9 @@
+import asyncio
 import os
 import signal
 import weakref
 from abc import ABC, abstractmethod
-from typing import List, Type
+from typing import List, Optional, Type
 
 import msgspec
 import zmq
@@ -255,10 +256,24 @@ def __init__(self, vllm_config: VllmConfig,
             log_stats=True,
         )
 
+        self.outputs_queue: Optional[asyncio.Queue[bytes]] = None
+        self.queue_task: Optional[asyncio.Task] = None
+
     async def get_output_async(self) -> EngineCoreOutputs:
+        if self.outputs_queue is None:
+            # Perform IO in separate task to parallelize as much as possible
+            self.outputs_queue = asyncio.Queue()
+
+            async def process_outputs_socket():
+                assert self.outputs_queue is not None
+                while True:
+                    (frame, ) = await self.output_socket.recv_multipart(
+                        copy=False)
+                    self.outputs_queue.put_nowait(frame.buffer)
+
+            self.queue_task = asyncio.create_task(process_outputs_socket())
 
-        frames = await self.output_socket.recv_multipart(copy=False)
-        return self.decoder.decode(frames[0].buffer)
+        return self.decoder.decode(await self.outputs_queue.get())
 
     async def _send_input(self, request_type: EngineCoreRequestType,
                           request: EngineCoreRequestUnion) -> None:
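
The client now starts, on first use, a dedicated task that keeps reading frames off the socket into an asyncio.Queue, so socket IO overlaps with decoding in the caller instead of alternating with it. A minimal sketch of the same lazy background-reader pattern with hypothetical stand-ins (recv_frame for the socket read, decode for msgspec decoding):

import asyncio
from typing import Any, Awaitable, Callable, Optional


class LazyBackgroundReader:
    # Illustrative helper, not vLLM's client.

    def __init__(self, recv_frame: Callable[[], Awaitable[bytes]],
                 decode: Callable[[bytes], Any]):
        self._recv_frame = recv_frame
        self._decode = decode
        self._queue: Optional["asyncio.Queue[bytes]"] = None
        self._task: Optional["asyncio.Task[None]"] = None

    async def get(self) -> Any:
        if self._queue is None:
            # Started lazily so the task is created inside a running loop.
            queue: "asyncio.Queue[bytes]" = asyncio.Queue()
            self._queue = queue

            async def pump() -> None:
                while True:
                    # IO happens here, concurrently with callers that are
                    # still decoding earlier frames.
                    queue.put_nowait(await self._recv_frame())

            self._task = asyncio.create_task(pump())
        return self._decode(await self._queue.get())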

vllm/v1/engine/output_processor.py

Lines changed: 4 additions & 2 deletions
@@ -101,6 +101,7 @@ def add_request(
     def process_outputs(
         self,
         engine_core_outputs: List[EngineCoreOutput],
+        iteration_stats: Optional[IterationStats] = None,
     ) -> OutputProcessorOutput:
         """
         Process the EngineCoreOutputs:
@@ -133,7 +134,8 @@ def process_outputs(
 
         request_outputs: List[RequestOutput] = []
         reqs_to_abort: List[str] = []
-        iteration_stats = IterationStats(self.log_stats)
+        if not iteration_stats:
+            iteration_stats = IterationStats(self.log_stats)
         for engine_core_output in engine_core_outputs:
             req_id = engine_core_output.request_id
             req_state = self.request_states.get(req_id)
@@ -175,8 +177,8 @@ def process_outputs(
             iteration_stats=iteration_stats,
         )
 
+    @staticmethod
     def _make_request_output(
-        self,
         request_state: RequestState,
         detokenizer_output: Optional[DetokenizerOutput],
     ) -> Optional[RequestOutput]:
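
Making iteration_stats an optional argument lets the caller thread one stats object through several chunked process_outputs calls, so counters accumulate across the whole batch instead of resetting per chunk. A toy version of the pattern (the names are illustrative, not vLLM's IterationStats):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ChunkStats:
    num_items: int = 0


def process_chunk(chunk: List[int],
                  stats: Optional[ChunkStats] = None) -> ChunkStats:
    # Reuse the caller's stats object when given; create a fresh one otherwise.
    if stats is None:
        stats = ChunkStats()
    stats.num_items += len(chunk)
    return stats


stats = None
for chunk in ([1, 2, 3], [4, 5]):
    stats = process_chunk(chunk, stats)
assert stats is not None and stats.num_items == 5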

vllm/v1/request.py

Lines changed: 6 additions & 12 deletions
@@ -64,6 +64,12 @@ def __init__(
         # recomputing.
         self._kv_block_hashes: List[BlockHashType] = []
 
+        # Read-only views
+        # Prevent directly appending to the these lists since
+        # they should also be updated simultaneously.
+        self.output_token_ids = ConstantList(self._output_token_ids)
+        self.all_token_ids = ConstantList(self._all_token_ids)
+
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
         return cls(
@@ -79,18 +85,6 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
             lora_request=request.lora_request,
         )
 
-    @property
-    def output_token_ids(self) -> ConstantList[int]:
-        # Prevent directly appending to the output_token_ids since
-        # all_token_ids should also be updated simultaneously.
-        return ConstantList(self._output_token_ids)
-
-    @property
-    def all_token_ids(self) -> ConstantList[int]:
-        # Prevent directly appending to the all_token_ids since
-        # output_token_ids should also be updated simultaneously
-        return ConstantList(self._all_token_ids)
-
     def append_output_token_ids(
         self,
         token_ids: Union[int, List[int]],
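
Replacing the two properties with views created once in __init__ removes a small wrapper allocation from every access on a hot path, while still preventing callers from appending to the lists directly. A sketch of the idea with a stand-in read-only wrapper (ReadOnlyView is illustrative; vLLM's actual wrapper is ConstantList from vllm.utils):

from collections.abc import Sequence
from typing import List


class ReadOnlyView(Sequence):
    # Wraps a list without copying; mutating methods are simply absent.

    def __init__(self, backing: List[int]):
        self._backing = backing

    def __getitem__(self, idx):
        return self._backing[idx]

    def __len__(self) -> int:
        return len(self._backing)


class TokenState:
    def __init__(self, prompt_token_ids: List[int]):
        self._output_token_ids: List[int] = []
        self._all_token_ids: List[int] = list(prompt_token_ids)
        # Views built once; they stay in sync because they share storage.
        self.output_token_ids = ReadOnlyView(self._output_token_ids)
        self.all_token_ids = ReadOnlyView(self._all_token_ids)

    def append_output_token_id(self, token_id: int) -> None:
        # Both backing lists are updated together, as the views require.
        self._output_token_ids.append(token_id)
        self._all_token_ids.append(token_id)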
