Skip to content

Commit f838fe2

Browse files
committed
scale out
Signed-off-by: Rui Qiao <[email protected]>
1 parent 0de3738 commit f838fe2

File tree

3 files changed

+117
-19
lines changed

3 files changed

+117
-19
lines changed

vllm/entrypoints/cli/serve.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
from vllm.v1.executor.abstract import Executor
2929
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
3030
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
31-
get_engine_client_zmq_addr,
31+
CoreEngineActorManager, get_engine_client_zmq_addr,
3232
wait_for_completion_or_failure,
33-
wait_for_engine_startup)
33+
wait_for_engine_startup, wait_for_ray_engine_actors)
3434

3535
logger = init_logger(__name__)
3636

@@ -212,6 +212,34 @@ def run_multi_api_server(args: argparse.Namespace):
212212
logger.info("Started DP Coordinator process (PID: %d)",
213213
coordinator.proc.pid)
214214

215+
if parallel_config.data_parallel_backend == "ray":
216+
logger.info("Starting ray-based data parallel backend")
217+
218+
engine_actor_manager = CoreEngineActorManager(
219+
local_engine_count=local_engine_count,
220+
start_index=args.data_parallel_start_rank,
221+
local_start_index=0,
222+
vllm_config=vllm_config,
223+
addresses=addresses,
224+
executor_class=Executor.get_class(vllm_config),
225+
log_stats=not engine_args.disable_log_stats,
226+
)
227+
# Start API servers using the manager
228+
api_server_manager = APIServerProcessManager(
229+
target_server_fn=run_api_server_worker,
230+
listen_address=listen_address,
231+
sock=sock,
232+
args=args,
233+
num_servers=num_api_servers,
234+
input_addresses=input_addresses,
235+
output_addresses=output_addresses,
236+
stats_update_address=stats_update_address)
237+
238+
wait_for_ray_engine_actors(api_server_manager=api_server_manager,
239+
engine_actor_manager=engine_actor_manager,
240+
coordinator=coordinator)
241+
return
242+
215243
handshake_address = get_engine_client_zmq_addr(
216244
local_only, host, parallel_config.data_parallel_rpc_port)
217245

vllm/v1/engine/core_client.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,17 +1144,6 @@ def __init__(
11441144
self.stats_update_address = \
11451145
coordinator.get_stats_publish_address()
11461146

1147-
# Start all engines.
1148-
self.resources.local_engine_manager = CoreEngineActorManager(
1149-
vllm_config=vllm_config,
1150-
executor_class=executor_class,
1151-
log_stats=log_stats,
1152-
input_address=input_address,
1153-
output_address=output_address,
1154-
local_engine_count=local_engine_count,
1155-
start_index=start_index,
1156-
local_start_index=local_start_index)
1157-
11581147
self.core_engine = self.core_engines[0]
11591148

11601149
self.utility_results: dict[int, AnyFuture] = {}
@@ -1180,3 +1169,36 @@ def __init__(
11801169
self._ensure_output_queue_task()
11811170
except RuntimeError:
11821171
pass
1172+
1173+
def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
1174+
local_start_index: int, input_address: str,
1175+
output_address: str,
1176+
executor_class: type[Executor], log_stats: bool):
1177+
"""Self-contained client mode, launch engine and coordinator process
1178+
as needed."""
1179+
1180+
parallel_config = vllm_config.parallel_config
1181+
local_engine_count = parallel_config.data_parallel_size_local
1182+
start_index = parallel_config.data_parallel_rank
1183+
1184+
if len(self.core_engines) > 1:
1185+
self.resources.coordinator = DPCoordinator(parallel_config)
1186+
1187+
addresses: dict[str, Any] = {
1188+
"input_addresses": [input_address],
1189+
"output_addresses": [output_address],
1190+
}
1191+
1192+
coordinator = self.resources.coordinator
1193+
if coordinator is not None:
1194+
addresses.update(coordinator.get_engine_socket_addresses())
1195+
1196+
# Start all engines.
1197+
self.resources.local_engine_manager = CoreEngineActorManager(
1198+
vllm_config=vllm_config,
1199+
executor_class=executor_class,
1200+
log_stats=log_stats,
1201+
addresses=addresses,
1202+
local_engine_count=local_engine_count,
1203+
start_index=start_index,
1204+
local_start_index=local_start_index)

vllm/v1/utils.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,7 @@ def __init__(
262262
start_index: int,
263263
local_start_index: int,
264264
vllm_config: VllmConfig,
265-
input_address: str,
266-
output_address: str,
265+
addresses,
267266
executor_class: type[Executor],
268267
log_stats: bool,
269268
):
@@ -284,8 +283,7 @@ def __init__(
284283
vllm_config=vllm_config,
285284
executor_class=executor_class,
286285
log_stats=log_stats,
287-
input_address=input_address,
288-
output_address=output_address,
286+
addresses=addresses,
289287
on_head_node=True,
290288
engine_index=global_index,
291289
dp_rank=global_index,
@@ -301,8 +299,7 @@ def __init__(
301299
vllm_config=vllm_config,
302300
executor_class=executor_class,
303301
log_stats=log_stats,
304-
input_address=input_address,
305-
output_address=output_address,
302+
addresses=addresses,
306303
on_head_node=False,
307304
engine_index=global_index,
308305
dp_rank=global_index,
@@ -490,6 +487,57 @@ def wait_for_completion_or_failure(
490487
local_engine_manager.close()
491488

492489

490+
def wait_for_ray_engine_actors(
491+
api_server_manager: APIServerProcessManager,
492+
engine_actor_manager: CoreEngineActorManager,
493+
coordinator: Optional["DPCoordinator"] = None) -> None:
494+
"""Wait for all ray engine actors to complete or detect if any fail.
495+
496+
Raises an exception if any process exits with a non-zero status.
497+
"""
498+
499+
try:
500+
logger.info("Waiting for ray engine actors to complete ...")
501+
# Create a mapping of sentinels to their corresponding processes
502+
# for efficient lookup
503+
sentinel_to_proc: dict[Any, Union[SpawnProcess, Process]] = {
504+
proc.sentinel: proc
505+
for proc in api_server_manager.processes
506+
}
507+
508+
if coordinator:
509+
sentinel_to_proc.update(
510+
{coordinator.proc.sentinel: coordinator.proc})
511+
512+
# TODO(rui): check if any ray engine actor terminates
513+
# Check if any process terminates
514+
while sentinel_to_proc:
515+
# Wait for any process to terminate
516+
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)
517+
518+
# Process any terminated processes
519+
for sentinel in ready_sentinels:
520+
proc = sentinel_to_proc.pop(sentinel)
521+
522+
# Check if process exited with error
523+
if proc.exitcode != 0:
524+
raise RuntimeError(
525+
f"Process {proc.name} (PID: {proc.pid}) "
526+
f"died with exit code {proc.exitcode}")
527+
except KeyboardInterrupt:
528+
logger.info("Received KeyboardInterrupt, shutting down API servers...")
529+
except Exception as e:
530+
logger.exception("Exception occurred while running API servers: %s",
531+
str(e))
532+
raise
533+
finally:
534+
logger.info("Terminating remaining processes ...")
535+
api_server_manager.close()
536+
if coordinator:
537+
coordinator.close()
538+
engine_actor_manager.close()
539+
540+
493541
# Note(rob): shutdown function cannot be a bound method,
494542
# else the gc cannot collect the object.
495543
def shutdown(procs: list[Process]):

0 commit comments

Comments
 (0)