Commit 55dd119
[Frontend][V1] Online serving performance improvements
These help in particular with TTFT and ITL variance. Overall throughput doesn't change much.

- Break up output processing (detokenization) to avoid blocking the event loop for too long
- Freeze the heap after startup to reduce GC overhead/pauses
- Optimize a couple of CPU hotspots seen during profiling

Signed-off-by: Nick Hill <[email protected]>
1 parent 9c485d9 commit 55dd119

File tree

5 files changed: +76 -41 lines

vllm/entrypoints/openai/api_server.py

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,6 @@
 import asyncio
 import atexit
+import gc
 import importlib
 import inspect
 import multiprocessing
@@ -104,6 +105,11 @@ async def _force_log():
         task.add_done_callback(_running_tasks.remove)
     else:
         task = None
+
+    # Mark the startup heap as static so that it's ignored by GC.
+    # Reduces pause times of oldest generation collections.
+    gc.collect()
+    gc.freeze()
     try:
         yield
     finally:

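For reference, a minimal standalone sketch of the same start-up pattern, using Python's documented gc API; the list below is just a stand-in for the long-lived engine state a real server allocates while starting up:

    import gc

    # Stand-in for long-lived objects allocated during start-up.
    startup_state = [{"request_id": i} for i in range(100_000)]

    gc.collect()  # reclaim any garbage produced while starting up
    gc.freeze()   # move all surviving objects to a permanent generation
                  # that later collections skip entirely

    # Number of objects the cyclic GC will now ignore (Python 3.7+).
    print(gc.get_freeze_count())

Without the freeze, every oldest-generation collection re-traverses the start-up heap, which shows up as periodic latency pauses; freezing takes those objects out of the scan.
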
vllm/entrypoints/openai/protocol.py

Lines changed: 19 additions & 11 deletions
@@ -3,7 +3,7 @@
 import re
 import time
 from argparse import Namespace
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union

 import torch
 from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -42,23 +42,31 @@ class OpenAIBaseModel(BaseModel):
     # OpenAI API does allow extra fields
     model_config = ConfigDict(extra="allow")

+    # Cache class field names
+    field_names: ClassVar[Optional[Set[str]]] = None
+
     @model_validator(mode="before")
     @classmethod
     def __log_extra_fields__(cls, data):
-        if isinstance(data, dict):
+
+        field_names = cls.field_names
+        if field_names is None:
+            if not isinstance(data, dict):
+                return data
             # Get all class field names and their potential aliases
             field_names = set()
             for field_name, field in cls.model_fields.items():
                 field_names.add(field_name)
-                if hasattr(field, 'alias') and field.alias:
-                    field_names.add(field.alias)
-
-            # Compare against both field names and aliases
-            extra_fields = data.keys() - field_names
-            if extra_fields:
-                logger.warning(
-                    "The following fields were present in the request "
-                    "but ignored: %s", extra_fields)
+                if alias := getattr(field, 'alias', None):
+                    field_names.add(alias)
+            cls.field_names = field_names
+
+        # Compare against both field names and aliases
+        if any(k not in field_names for k in data):
+            logger.warning(
+                "The following fields were present in the request "
+                "but ignored: %s",
+                data.keys() - field_names)
         return data

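The ClassVar cache means the field-name/alias set is built on the first request and reused afterwards, instead of being rebuilt from cls.model_fields on every request. A self-contained sketch of the same pattern with a hypothetical model (pydantic v2 API):

    from typing import Any, ClassVar, Optional, Set

    from pydantic import BaseModel, ConfigDict, Field, model_validator

    class ExampleRequest(BaseModel):
        model_config = ConfigDict(extra="allow")

        # Populated lazily on the first validation, then reused.
        field_names: ClassVar[Optional[Set[str]]] = None

        n: int = Field(alias="num")

        @model_validator(mode="before")
        @classmethod
        def _warn_extra_fields(cls, data: Any):
            names = cls.field_names
            if names is None:
                if not isinstance(data, dict):
                    return data
                names = set()
                for name, field in cls.model_fields.items():
                    names.add(name)
                    if alias := getattr(field, "alias", None):
                        names.add(alias)
                cls.field_names = names
            if any(k not in names for k in data):
                print("ignored fields:", data.keys() - names)
            return data

    ExampleRequest(num=1)             # no output: "num" is a known alias
    ExampleRequest(num=2, foo="bar")  # prints: ignored fields: {'foo'}

Note that cls.field_names is assigned on the concrete subclass, so each request model caches its own name set.
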
vllm/v1/engine/async_llm.py

Lines changed: 41 additions & 16 deletions
@@ -1,7 +1,10 @@
 import asyncio
+import math
 import os
 from typing import AsyncGenerator, List, Mapping, Optional, Type, Union

+import numpy as np
+
 from vllm.config import ModelConfig, VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.protocol import EngineClient
@@ -26,6 +29,11 @@

 logger = init_logger(__name__)

+# For now determined empirically.
+# Larger => higher ITL variance
+# Smaller => higher TTFT, throughput impacted
+OUTPUT_PROCESSING_CHUNK_SIZE = 128
+

 class AsyncLLM(EngineClient):

@@ -205,17 +213,15 @@ async def generate(

             # The output_handler task pushes items into the queue.
             # This task pulls from the queue and yields to caller.
-            while True:
+            finished = False
+            while not finished:
                 # Note: drain queue without await if possible (avoids
                 # task switching under load which helps performance).
-                out = q.get_nowait() if q.qsize() > 0 else await q.get()
+                out = q.get_nowait() if not q.empty() else await q.get()

                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
-                if out.finished:
-                    yield out
-                    break
-
+                finished = out.finished
                 yield out

             # If the request is disconnected by the client, the
@@ -233,22 +239,41 @@ async def _run_output_handler(self):
                 # 1) Pull EngineCoreOutputs from the EngineCore.
                 outputs = await self.engine_core.get_output_async()

-                # 2) Process EngineCoreOutputs.
-                processed_outputs = self.output_processor.process_outputs(
-                    outputs.outputs)
-                # NOTE: RequestOutputs are pushed to their queues.
-                assert len(processed_outputs.request_outputs) == 0
-
-                # 3) Abort any reqs that finished due to stop strings.
-                await self.engine_core.abort_requests_async(
-                    processed_outputs.reqs_to_abort)
+                # Split outputs into chunks of at most
+                # OUTPUT_PROCESSING_CHUNK_SIZE, so that we don't block the
+                # event loop for too long.
+                num_outputs = len(outputs.outputs)
+                if num_outputs <= OUTPUT_PROCESSING_CHUNK_SIZE:
+                    slices = (outputs.outputs, )
+                else:
+                    slices = np.array_split(outputs.outputs,
+                        math.ceil(num_outputs / OUTPUT_PROCESSING_CHUNK_SIZE))
+
+                iteration_stats = None
+                for i, outputs_slice in enumerate(slices):
+
+                    # 2) Process EngineCoreOutputs.
+                    processed_outputs = self.output_processor.process_outputs(
+                        outputs_slice, iteration_stats)
+                    # NOTE: RequestOutputs are pushed to their queues.
+                    assert not processed_outputs.request_outputs
+                    iteration_stats = processed_outputs.iteration_stats
+
+                    # Allow other asyncio tasks to run between chunks
+                    if i + 1 < len(slices):
+                        await asyncio.sleep(0)
+
+                    # 3) Abort any reqs that finished due to stop strings.
+                    await self.engine_core.abort_requests_async(
+                        processed_outputs.reqs_to_abort)

                 # 4) Logging.
                 # TODO(rob): make into a coroutine and launch it in
                 # background thread once we add Prometheus.
+                assert iteration_stats is not None
                 self._log_stats(
                     scheduler_stats=outputs.scheduler_stats,
-                    iteration_stats=processed_outputs.iteration_stats,
+                    iteration_stats=iteration_stats,
                 )

             except Exception as e:
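
In generate(), the loop now treats the final output like any other, tracking out.finished in a flag, and tests q.empty() instead of q.qsize() > 0, a marginally cheaper check on this hot path. A runnable sketch of the drain pattern, with a None sentinel standing in for out.finished:

    import asyncio

    async def consume(q: asyncio.Queue):
        finished = False
        while not finished:
            # Drain synchronously while items are available; this avoids a
            # task switch per item when the producer runs ahead of us.
            item = q.get_nowait() if not q.empty() else await q.get()
            finished = item is None  # stand-in for out.finished
            if not finished:
                print("got:", item)

    async def main():
        q: asyncio.Queue = asyncio.Queue()
        for i in range(3):
            q.put_nowait(i)
        q.put_nowait(None)  # sentinel: "request finished"
        await consume(q)

    asyncio.run(main())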

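The chunked handler is the core TTFT/ITL change: one large batch of EngineCoreOutputs is processed in slices, with an await asyncio.sleep(0) between slices so queued tasks (such as per-request generators streaming tokens back to clients) get scheduled in between. A self-contained sketch of the pattern; the chunk size and handler below are illustrative:

    import asyncio
    import math

    import numpy as np

    CHUNK_SIZE = 4  # stand-in for OUTPUT_PROCESSING_CHUNK_SIZE

    async def process_chunked(items, handle):
        if len(items) <= CHUNK_SIZE:
            slices = (items, )
        else:
            slices = np.array_split(items,
                                    math.ceil(len(items) / CHUNK_SIZE))
        for i, s in enumerate(slices):
            for item in s:
                handle(item)  # stand-in for process_outputs on one slice
            # Yield to the event loop between chunks, but not after the
            # last one, mirroring the `i + 1 < len(slices)` check above.
            if i + 1 < len(slices):
                await asyncio.sleep(0)

    asyncio.run(process_chunked(list(range(10)), print))

await asyncio.sleep(0) is the idiomatic way to yield control without imposing a delay; per the comment on OUTPUT_PROCESSING_CHUNK_SIZE, the chunk size trades TTFT and throughput (worse when smaller) against ITL variance (worse when larger).
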
vllm/v1/engine/output_processor.py

Lines changed: 4 additions & 2 deletions
@@ -101,6 +101,7 @@ def add_request(
     def process_outputs(
         self,
         engine_core_outputs: List[EngineCoreOutput],
+        iteration_stats: Optional[IterationStats] = None,
     ) -> OutputProcessorOutput:
         """
         Process the EngineCoreOutputs:
@@ -133,7 +134,8 @@ def process_outputs(

         request_outputs: List[RequestOutput] = []
         reqs_to_abort: List[str] = []
-        iteration_stats = IterationStats(self.log_stats)
+        if not iteration_stats:
+            iteration_stats = IterationStats(self.log_stats)
         for engine_core_output in engine_core_outputs:
             req_id = engine_core_output.request_id
             req_state = self.request_states.get(req_id)
@@ -175,8 +177,8 @@ def process_outputs(
             iteration_stats=iteration_stats,
         )

+    @staticmethod
     def _make_request_output(
-        self,
         request_state: RequestState,
         detokenizer_output: Optional[DetokenizerOutput],
     ) -> Optional[RequestOutput]:

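Threading the stats object through as an optional parameter lets all slices of one engine step accumulate into a single IterationStats, while callers that pass nothing keep the old behaviour of a fresh instance per call. A minimal sketch of the pattern (Stats and process_chunk are hypothetical names):

    from typing import List, Optional

    class Stats:
        """Hypothetical stand-in for IterationStats."""

        def __init__(self):
            self.num_tokens = 0

    def process_chunk(chunk: List[int],
                      stats: Optional[Stats] = None) -> Stats:
        if stats is None:
            stats = Stats()  # first chunk of the step: start fresh
        stats.num_tokens += len(chunk)  # later chunks accumulate
        return stats

    stats = None
    for chunk in ([1, 2, 3], [4, 5]):
        stats = process_chunk(chunk, stats)
    assert stats is not None and stats.num_tokens == 5
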
vllm/v1/request.py

Lines changed: 6 additions & 12 deletions
@@ -64,6 +64,12 @@ def __init__(
         # recomputing.
         self._kv_block_hashes: List[BlockHashType] = []

+        # Read-only views
+        # Prevent directly appending to these lists since
+        # they should also be updated simultaneously.
+        self.output_token_ids = ConstantList(self._output_token_ids)
+        self.all_token_ids = ConstantList(self._all_token_ids)
+
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
         return cls(
@@ -79,18 +85,6 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
             lora_request=request.lora_request,
         )

-    @property
-    def output_token_ids(self) -> ConstantList[int]:
-        # Prevent directly appending to the output_token_ids since
-        # all_token_ids should also be updated simultaneously.
-        return ConstantList(self._output_token_ids)
-
-    @property
-    def all_token_ids(self) -> ConstantList[int]:
-        # Prevent directly appending to the all_token_ids since
-        # output_token_ids should also be updated simultaneously
-        return ConstantList(self._all_token_ids)
-
     def append_output_token_ids(
         self,
         token_ids: Union[int, List[int]],

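The properties became plain attributes because each property access allocated a fresh ConstantList wrapper, likely one of the profiled CPU hotspots the commit message mentions. Building the views once in __init__ removes that per-access allocation while still reflecting later appends, since the views wrap the same underlying lists. A sketch with a hypothetical minimal view class (not vLLM's actual ConstantList):

    from typing import Generic, List, TypeVar

    T = TypeVar("T")

    class ReadOnlyView(Generic[T]):
        """Wraps a list: reads pass through, no mutators are exposed."""

        def __init__(self, items: List[T]):
            self._items = items

        def __getitem__(self, index):
            return self._items[index]

        def __len__(self) -> int:
            return len(self._items)

    tokens: List[int] = []
    view = ReadOnlyView(tokens)  # constructed once, as in __init__ above
    tokens.append(7)             # the owner mutates the underlying list
    assert len(view) == 1 and view[0] == 7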