
Commit 080a4bf

add sync_openai api_server (#365)
1 parent: d9385b4

File tree

3 files changed: +558 additions, 0 deletions

vllm/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.llm import LLM
+from vllm.entrypoints.fast_sync_llm import FastSyncLLM
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
@@ -21,6 +22,7 @@
     "__version__",
     "__version_tuple__",
     "LLM",
+    "FastSyncLLM",
     "ModelRegistry",
     "PromptType",
     "TextPrompt",

vllm/entrypoints/fast_sync_llm.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
import multiprocessing as mp
from queue import Empty
from typing import Union

import vllm.envs as envs
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor
from vllm.executor.ray_gpu_executor import RayGPUExecutor
from vllm.inputs import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter

logger = init_logger(__name__)


class FastSyncLLM:

    def __init__(
        self,
        engine_args: EngineArgs,
        input_queue: mp.Queue,
        result_queue: mp.Queue,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        self.engine_args = engine_args
        self.request_counter = Counter()

        self.input_queue = input_queue
        self.result_queue = result_queue
        self.finish = False
        self.need_restart = False
        self.llm_engine: LLMEngine

    def _add_request(
        self,
        inputs: PromptType,
        params: Union[SamplingParams, PoolingParams],
        request_id: str,
    ) -> None:
        if isinstance(inputs, list):
            inputs = TokensPrompt(prompt_token_ids=inputs)
        self.llm_engine.add_request(request_id, inputs, params)

    def _poll_requests(self):
        while True:
            if not self.llm_engine.has_unfinished_requests():
                logger.info("No unfinished requests. Waiting...")
                (request_id, prompt, sampling_params) = self.input_queue.get()
                if self.need_restart and isinstance(
                        self.llm_engine.model_executor,
                        MultiprocessingGPUExecutor):
                    logger.info("Restarting worker loops")
                    for worker in self.llm_engine.model_executor.workers:
                        worker.execute_method("start_worker_execution_loop")
                    self.need_restart = False

            else:
                try:
                    (request_id, prompt,
                     sampling_params) = self.input_queue.get_nowait()
                except Empty:
                    break
            self._add_request(prompt, sampling_params, request_id)

    def run_engine(self):
        self.llm_engine = LLMEngine.from_engine_args(
            self.engine_args, usage_context=UsageContext.LLM_CLASS)
        assert not isinstance(
            self.llm_engine.model_executor,
            RayGPUExecutor), "Ray is not supported in sync openai mode"

        self.result_queue.put(("Ready", None, None))
        prompt_lens = {}
        tokens = {}  # type: ignore
        log_interval = 100
        poll_interval = envs.VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS
        try:
            while True:
                poll_interval -= 1
                if (self.input_queue.qsize() >=
                        envs.VLLM_SYNC_SERVER_ACCUM_REQUESTS
                        or poll_interval <= 0
                        or not self.llm_engine.has_unfinished_requests()):
                    self._poll_requests()
                    poll_interval = \
                        envs.VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS
                step_outputs = self.llm_engine.step()
                log_interval -= 1
                if log_interval == 0:
                    log_interval = 100
                    logger.info("Step finished. Unfinished requests: %d",
                                self.llm_engine.get_num_unfinished_requests())
                if not self.llm_engine.has_unfinished_requests():
                    logger.info("Broadcast stop")
                    broadcast_tensor_dict({}, src=0)
                    self.need_restart = True
                for output in step_outputs:
                    assert len(output.outputs) == 1  # type: ignore
                    first_out = output.outputs[0]  # type: ignore
                    stats = None
                    result = first_out.text
                    tokens[output.request_id] = tokens.get(
                        output.request_id, 0) + len(first_out.token_ids)
                    if output.prompt_token_ids is not None:
                        prompt_lens[output.request_id] = len(
                            output.prompt_token_ids)
                    if output.finished:
                        assert output.request_id in prompt_lens
                        stats = {
                            "prompt": prompt_lens[output.request_id],
                            "tokens": tokens[output.request_id],
                            "finish_reason": first_out.finish_reason,
                            "stop_reason": first_out.stop_reason,
                        }
                        del prompt_lens[output.request_id]
                    self.result_queue.put_nowait(
                        (output.request_id, result, stats))

        except Exception as e:
            logger.error("Error in run_engine: %s", e)
            raise e
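
The class is driven entirely through the two multiprocessing queues: the caller pushes (request_id, prompt, sampling_params) tuples into input_queue, and run_engine() steps the engine and streams (request_id, text, stats) tuples back on result_queue, with stats populated only when a request finishes. Below is a minimal driver sketch along those lines; it is not part of this commit, the model name and the spawn/terminate handling are illustrative assumptions, and the real wiring lives in the sync_openai api_server added by the rest of this PR (the VLLM_SYNC_SERVER_* knobs referenced by run_engine() are likewise assumed to be defined elsewhere in the change).

# Illustrative driver, not part of this commit.
import multiprocessing as mp

from vllm import FastSyncLLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs


def engine_main(engine_args, input_queue, result_queue):
    # Build the engine inside the child process and block in its step loop.
    FastSyncLLM(engine_args, input_queue, result_queue).run_engine()


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    input_queue = ctx.Queue()
    result_queue = ctx.Queue()
    engine_args = EngineArgs(model="facebook/opt-125m")  # hypothetical model choice

    proc = ctx.Process(target=engine_main,
                       args=(engine_args, input_queue, result_queue))
    proc.start()

    # run_engine() reports readiness with a ("Ready", None, None) tuple.
    assert result_queue.get() == ("Ready", None, None)

    # Requests are (request_id, prompt, sampling_params) tuples; the prompt may
    # be a string or a list of token ids.
    input_queue.put(("req-0", "Hello, my name is",
                     SamplingParams(max_tokens=16)))

    # Results stream back as (request_id, text, stats); stats stays None until
    # the request is finished.
    while True:
        request_id, text, stats = result_queue.get()
        if stats is not None:
            print(request_id, text, stats)
            break

    # The engine loop blocks waiting for more work; stop it explicitly here.
    proc.terminate()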
