From 88025215dc692ac28e1abb79e18892760d53f109 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 2 Apr 2025 14:07:18 -0700 Subject: [PATCH 01/22] [V1] DP scale-out (2/N): Decouple engine process management and comms Signed-off-by: Nick Hill --- vllm/config.py | 5 + vllm/engine/arg_utils.py | 7 + vllm/v1/engine/core.py | 98 +++++++++----- vllm/v1/engine/core_client.py | 247 ++++++++++++++++++++-------------- vllm/v1/utils.py | 22 ++- 5 files changed, 232 insertions(+), 147 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 2669d1a13b3..2d9854f865b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1430,6 +1430,7 @@ class ParallelConfig: pipeline_parallel_size: int = 1 # Number of pipeline parallel groups. tensor_parallel_size: int = 1 # Number of tensor parallel groups. data_parallel_size: int = 1 # Number of data parallel groups. + data_parallel_size_local: int = 1 # Number of data parallel groups. data_parallel_rank: int = 0 # Rank of the data parallel group. # Local rank of the data parallel group, defaults to global rank. data_parallel_rank_local: Optional[int] = None @@ -1537,6 +1538,10 @@ def __post_init__(self) -> None: self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size + if not (0 < self.data_parallel_size_local <= self.data_parallel_size): + raise ValueError( + "data_parallel_size_local must be <= data_parallel_size") + if self.data_parallel_size > 1: # Data parallel was specified in the engine args. self.data_parallel_master_port = get_open_port() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 89c9b67470e..d6965a08ef0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -116,6 +116,7 @@ class EngineArgs: pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 data_parallel_size: int = 1 + data_parallel_size_local: Optional[int] = None enable_expert_parallel: bool = False max_parallel_loading_workers: Optional[int] = None block_size: Optional[int] = None @@ -1186,10 +1187,16 @@ def create_engine_config( # but we should not do this here. placement_group = ray.util.get_current_placement_group() + # Local DP size defaults to global DP size if not set. 
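# (An explicit value smaller than data_parallel_size means the
# remaining DP ranks are expected to join from other nodes.)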
+ data_parallel_size_local = self.data_parallel_size if ( + self.data_parallel_size_local + is None) else self.data_parallel_size_local + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, data_parallel_size=self.data_parallel_size, + data_parallel_size_local=data_parallel_size_local, enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f58c77e4f16..b7fc3eb6ff7 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -22,8 +22,8 @@ from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname, - zmq_socket_ctx) +from vllm.utils import (get_exception_traceback, make_zmq_socket, + resolve_obj_by_qualname, zmq_socket_ctx) from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface @@ -309,9 +309,9 @@ class EngineCoreProc(EngineCore): def __init__( self, - input_path: str, - output_path: str, vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, executor_class: type[Executor], log_stats: bool, engine_index: int = 0, @@ -323,6 +323,19 @@ def __init__( self.global_unfinished_reqs = False + # Create input socket. + input_ctx = zmq.Context() # type: ignore[attr-defined] + identity = engine_index.to_bytes(length=2, byteorder="little") + input_socket = make_zmq_socket(input_ctx, + input_address, + zmq.DEALER, + identity=identity, + bind=False) + + # Register engine with front-end. + output_address = self.startup_handshake(input_socket, on_head_node, + vllm_config.parallel_config) + # Background Threads and Queues for IO. These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, # and to overlap some serialization/deserialization with the @@ -332,12 +345,39 @@ def __init__( Any]] = queue.Queue() self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() threading.Thread(target=self.process_input_socket, - args=(input_path, engine_index), + args=(input_socket, ), daemon=True).start() threading.Thread(target=self.process_output_socket, - args=(output_path, engine_index), + args=(output_address, engine_index), daemon=True).start() + @staticmethod + def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, + parallel_config: ParallelConfig) -> str: + + # Send registration message. + input_socket.send( + msgspec.msgpack.encode({ + "local": on_head_node, + "status": "READY" + })) + + # Receive initialization message. 
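# The front-end answers on this same DEALER socket with a msgpack map
# carrying the address to publish outputs to, plus the parallel-config
# fields only the head node knows (master IP/port, DP size).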
+ logger.info("Waiting for init message from front-end.") + input_socket.poll(timeout=5 * 60 * 1000) + init_bytes = input_socket.recv() + init_message = msgspec.msgpack.decode(init_bytes) + logger.debug("Received init message: %s", init_message) + + output_socket_address = init_message["output_socket_address"] + #TBD maybe replace IP with configured head node address + + received_parallel_config = init_message["parallel_config"] + for key, value in received_parallel_config.items(): + setattr(parallel_config, key, value) + + return output_socket_address + @staticmethod def run_engine_core(*args, dp_rank: int = 0, @@ -472,35 +512,25 @@ def _convert_msgspec_args(method, args): and not isinstance(v, p.annotation) else v for v, p in zip(args, arg_types)) - def process_input_socket(self, input_path: str, engine_index: int): + def process_input_socket(self, input_socket: zmq.Socket): """Input socket IO thread.""" # Msgpack serialization decoding. add_request_decoder = MsgpackDecoder(EngineCoreRequest) generic_decoder = MsgpackDecoder() - identity = engine_index.to_bytes(length=2, byteorder="little") - - with zmq_socket_ctx(input_path, - zmq.DEALER, - identity=identity, - bind=False) as socket: - # Send ready message to front-end once input socket is connected. - socket.send(b'READY') - - while True: - # (RequestType, RequestData) - type_frame, data_frame = socket.recv_multipart(copy=False) - request_type = EngineCoreRequestType(bytes(type_frame.buffer)) + while True: + # (RequestType, RequestData) + type_frame, data_frame = input_socket.recv_multipart(copy=False) + request_type = EngineCoreRequestType(bytes(type_frame.buffer)) - # Deserialize the request data. - decoder = add_request_decoder if ( - request_type - == EngineCoreRequestType.ADD) else generic_decoder - request = decoder.decode(data_frame.buffer) + # Deserialize the request data. + decoder = add_request_decoder if ( + request_type == EngineCoreRequestType.ADD) else generic_decoder + request = decoder.decode(data_frame.buffer) - # Push to input queue for core busy loop. - self.input_queue.put_nowait((request_type, request)) + # Push to input queue for core busy loop. + self.input_queue.put_nowait((request_type, request)) def process_output_socket(self, output_path: str, engine_index: int): """Output socket IO thread.""" @@ -527,9 +557,9 @@ class DPEngineCoreProc(EngineCoreProc): def __init__( self, - input_path: str, - output_path: str, vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -551,17 +581,17 @@ def __init__( from vllm.platforms import current_platform if current_platform.is_cuda_alike(): from vllm.platforms.cuda import device_id_to_physical_device_id - tp_size = vllm_config.parallel_config.tensor_parallel_size + world_size = vllm_config.parallel_config.world_size os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( str(device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * - tp_size)) + for i in range(local_dp_rank * + world_size, (local_dp_rank + 1) * world_size)) self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() # Initialize the engine after setting up environment. - super().__init__(input_path, output_path, vllm_config, executor_class, - log_stats, dp_rank) + super().__init__(vllm_config, on_head_node, input_address, + executor_class, log_stats, dp_rank) # Counts forward-passes of the model so that we can synchronize # finished with DP peers every N steps. 
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index b94b0aa7538..ace6470faae 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -10,18 +10,20 @@ from abc import ABC, abstractmethod from collections.abc import Awaitable from concurrent.futures import Future -from dataclasses import dataclass, field +from dataclasses import dataclass from threading import Thread from typing import Any, Callable, Optional, TypeVar, Union +import msgspec import zmq import zmq.asyncio -from vllm.config import VllmConfig +from vllm.config import ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path, - kill_process_tree, make_zmq_socket) +from vllm.utils import (get_open_port, get_open_zmq_inproc_path, + get_open_zmq_ipc_path, kill_process_tree, + make_zmq_socket) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.core import EngineCore, EngineCoreProc @@ -255,46 +257,59 @@ def collective_rpc(self, return self.engine_core.collective_rpc(method, timeout, args, kwargs) -class CoreEngine: +class CoreEngineProcManager: """One per data parallel rank.""" def __init__( self, + local_engine_count: int, + start_index: int, vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, executor_class: type[Executor], log_stats: bool, - input_path: str, - output_path: str, - index: int = 0, - local_dp_rank: int = 0, ): - self.index = index - self.identity = index.to_bytes(length=2, byteorder="little") + self.proc_handles = [] try: - # Start EngineCore in background process. - self.proc_handle = BackgroundProcHandle( - input_path=input_path, - output_path=output_path, - process_name=f"EngineCore_{index}", - target_fn=EngineCoreProc.run_engine_core, - process_kwargs={ - "vllm_config": vllm_config, - "dp_rank": index, - "local_dp_rank": local_dp_rank, - "executor_class": executor_class, - "log_stats": log_stats, - }) - - self.num_reqs_in_flight = 0 + for local_index in range(local_engine_count): + index = local_index + start_index + # Start EngineCore in background process. + self.proc_handles.append( + BackgroundProcHandle( + input_address=input_address, + process_name=f"EngineCore_{index}", + target_fn=EngineCoreProc.run_engine_core, + process_kwargs={ + "vllm_config": vllm_config, + "on_head_node": on_head_node, + "dp_rank": index, + "local_dp_rank": local_index, + "executor_class": executor_class, + "log_stats": log_stats, + })) finally: - if not hasattr(self, "num_reqs_in_flight"): - # Ensure socket is closed if process fails to start. 
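# (The replacement check below plays the same role for a batch of procs:
# if the loop raised part-way through, close() tears down the handles
# that were already created.)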
+ if len(self.proc_handles) != local_engine_count: self.close() def close(self): - if proc_handle := getattr(self, "proc_handle", None): + for proc_handle in self.proc_handles: proc_handle.shutdown() + def finished_procs(self) -> dict[int, int]: + return { + handle.proc.name: handle.proc.exitcode + for handle in self.proc_handles if handle.proc.exitcode is not None + } + + +class CoreEngine: + """One per data parallel rank.""" + + def __init__(self, index: int = 0): + self.identity = index.to_bytes(length=2, byteorder="little") + self.num_reqs_in_flight = 0 + @dataclass class BackgroundResources: @@ -302,7 +317,7 @@ class BackgroundResources: circular reference back to the client object.""" ctx: Union[zmq.Context] - core_engines: list[CoreEngine] = field(default_factory=list) + local_engine_manager: Optional[CoreEngineProcManager] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None shutdown_path: Optional[str] = None @@ -310,8 +325,8 @@ class BackgroundResources: def __call__(self): """Clean up background resources.""" - for core_engine in self.core_engines: - core_engine.close() + if self.local_engine_manager is not None: + self.local_engine_manager.close() # ZMQ context termination can hang if the sockets # aren't explicitly closed first. @@ -383,67 +398,111 @@ def sigusr1_handler(signum, frame): self.resources = BackgroundResources(ctx=sync_ctx) self._finalizer = weakref.finalize(self, self.resources) - # Paths and sockets for IPC. - self.output_path = get_open_zmq_ipc_path() - input_path = get_open_zmq_ipc_path() - self.input_socket = make_zmq_socket(self.ctx, - input_path, - zmq.ROUTER, - bind=True) - self.resources.input_socket = self.input_socket - - new_core_engine = lambda index, local_dp_rank=None: CoreEngine( - vllm_config, executor_class, log_stats, input_path, self. - output_path, index, local_dp_rank) + # TODO + parallel_config = vllm_config.parallel_config + dp_size = parallel_config.data_parallel_size + local_engine_count = parallel_config.data_parallel_size_local - # Start engine core process(es). - self._init_core_engines(vllm_config, new_core_engine, - self.resources.core_engines) + # TODO somewhere validate local count <= dp_size + if local_engine_count == dp_size: + input_address = get_open_zmq_ipc_path() + output_address = get_open_zmq_ipc_path() + else: + host = parallel_config.data_parallel_master_ip + input_port = 13345 # todo from arg/config + output_port = get_open_port() + input_address = f"tcp://{host}:{input_port}" + output_address = f"tcp://{host}:{output_port}" + + # Create input and output sockets. + self.input_socket = self.resources.input_socket = make_zmq_socket( + self.ctx, input_address, zmq.ROUTER, bind=True) + + self.resources.output_socket = make_zmq_socket(self.ctx, + output_address, + zmq.constants.PULL) + + # Start local engines. + if local_engine_count: + self.resources.local_engine_manager = CoreEngineProcManager( + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + input_address=input_address, + on_head_node=True, + local_engine_count=local_engine_count, + start_index=0) + + self.core_engines = [CoreEngine(i) for i in range(dp_size)] + self.core_engine = self.core_engines[0] # Wait for engine core process(es) to start. 
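# Startup now also hands each engine its output address and DP config,
# so "started" means the full handshake completed, not just proc launch.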
- self._wait_for_engine_startup() + self._wait_for_engine_startup(output_address, parallel_config) self.utility_results: dict[int, AnyFuture] = {} - def _wait_for_engine_startup(self): + def _wait_for_engine_startup(self, output_address: str, + parallel_config: ParallelConfig): # Get a sync handle to the socket which can be sync or async. sync_input_socket = zmq.Socket.shadow(self.input_socket) # Wait for engine core process(es) to send ready messages. - identities = set(eng.index for eng in self.resources.core_engines) - while identities: + local_engine_count = parallel_config.data_parallel_size_local + # TODO offline case compatibility + local_indices = set(range(local_engine_count)) + remote_indices = set( + range(len(self.core_engines) - local_engine_count)) + while local_indices or remote_indices: while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS): - logger.info("Waiting for %d core engine proc(s) to start: %s", - len(identities), identities) - eng_id_bytes, msg = sync_input_socket.recv_multipart() - eng_id = int.from_bytes(eng_id_bytes, byteorder="little") - if eng_id not in identities: - raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}") - if msg != b'READY': - raise RuntimeError(f"Engine {eng_id} failed: {msg.decode()}") - logger.info("Core engine process %d ready.", eng_id) - identities.discard(eng_id) + local_count = len(local_indices) + if remote_indices: + remote_count = len(remote_indices) + logger.info( + "Waiting for %d local and %d remote core engine " + "proc(s) to start: %s, %s", local_count, remote_count, + local_indices, remote_indices) + else: + logger.info( + "Waiting for %d local core engine proc(s) " + "to start: %s", local_count, local_indices) + eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() + ready_msg = msgspec.msgpack.decode(ready_msg_bytes) + local, status = ready_msg["local"], ready_msg["status"] + eng_index = int.from_bytes(eng_identity, byteorder="little") + if status != "READY": + raise RuntimeError(f"{'Local' if local else 'Remote'} engine " + f"{eng_index} failed: {status}") + + index_set = local_indices if local else remote_indices + if eng_index not in index_set: + raise RuntimeError( + f"Unexpected or duplicate " + f"{'local' if local else 'remote'} engine: {eng_index}") + + # Send init message with DP config info. + init_message = self.encoder.encode({ + "output_socket_address": output_address, + "parallel_config": { + "data_parallel_master_ip": + parallel_config.data_parallel_master_ip, + "data_parallel_master_port": + parallel_config.data_parallel_master_port, + "data_parallel_size": parallel_config.data_parallel_size, + }, + }) + + sync_input_socket.send_multipart((eng_identity, init_message), + copy=False) + + logger.debug("%s core engine process %d ready.", + "Local" if local else "Remote", eng_index) + index_set.discard(eng_index) # Double check that the process are running. - for engine in self.resources.core_engines: - proc = engine.proc_handle.proc - if proc.exitcode is not None: - raise RuntimeError(f"Engine proc {proc.name} not running") - - def _init_core_engines( - self, - vllm_config: VllmConfig, - new_core_engine: Callable[[int, Optional[int]], CoreEngine], - core_engines: list[CoreEngine], - ) -> None: - - # Default case - single core engine. 
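# (This single-engine construction hook is removed; proc startup now
# belongs to CoreEngineProcManager and the lightweight CoreEngine
# records above.)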
- dp_rank = vllm_config.parallel_config.data_parallel_rank - local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local - core_engine = new_core_engine( - dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank) - core_engines.append(core_engine) - self.core_engine = core_engine + engine_manager = self.resources.local_engine_manager + if engine_manager and (procs := engine_manager.finished_procs()): + raise RuntimeError( + f"Local engine proc(s) exited unexpectedly: {procs}") def shutdown(self): self._finalizer() @@ -476,7 +535,8 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], # Ensure that the outputs socket processing thread does not have # a ref to the client which prevents gc. ctx = self.ctx - output_path = self.output_path + out_socket = self.resources.output_socket + assert out_socket is not None decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue @@ -486,7 +546,6 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], def process_outputs_socket(): shutdown_socket = ctx.socket(zmq.PAIR) - out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL) try: shutdown_socket.bind(shutdown_path) poller = zmq.Poller() @@ -518,6 +577,9 @@ def process_outputs_socket(): daemon=True) self.output_queue_thread.start() + # The thread takes on responsibility for closing the socket. + self.resources.output_socket = None + def get_output(self) -> EngineCoreOutputs: return self.outputs_queue.get() @@ -621,10 +683,8 @@ def _ensure_output_queue_task(self): outputs_queue = self.outputs_queue output_handler = self.outputs_handler _self_ref = weakref.ref(self) if output_handler else None - output_path = self.output_path - output_socket = make_zmq_socket(self.ctx, output_path, - zmq.constants.PULL) - self.resources.output_socket = output_socket + output_socket = self.resources.output_socket + assert output_socket is not None async def process_outputs_socket(): while True: @@ -762,21 +822,6 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], self.outputs_handler = DPAsyncMPClient.process_engine_outputs # type: ignore[assignment] - def _init_core_engines( - self, - vllm_config: VllmConfig, - new_core_engine: Callable[[int, Optional[int]], CoreEngine], - core_engines: list[CoreEngine], - ) -> None: - - # Launch a core engine for each data parallel rank. - dp_size = vllm_config.parallel_config.data_parallel_size - for i in range(dp_size): - # Multi-node not yet supported so local_dp_rank == dp_rank. - core_engines.append(new_core_engine(i, i)) - - self.core_engines = core_engines - async def call_utility_async(self, method: str, *args) -> Any: # Only the result from the first engine is returned. return (await asyncio.gather(*[ diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index fed5761b04b..470e2a572ed 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -98,25 +98,22 @@ class BackgroundProcHandle: def __init__( self, - input_path: str, - output_path: str, + input_address: str, process_name: str, target_fn: Callable, process_kwargs: dict[Any, Any], ): context = get_mp_context() - assert ("input_path" not in process_kwargs - and "output_path" not in process_kwargs) - process_kwargs["input_path"] = input_path - process_kwargs["output_path"] = output_path + assert "input_address" not in process_kwargs + process_kwargs["input_address"] = input_address # Run busy loop in background process. 
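# The child comes from get_mp_context(), so target_fn and every kwarg
# must be importable/picklable under the configured start method.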
self.proc = context.Process(target=target_fn, kwargs=process_kwargs, name=process_name) self._finalizer = weakref.finalize(self, shutdown, self.proc, - input_path, output_path) + input_address) self.proc.start() def shutdown(self): @@ -125,7 +122,7 @@ def shutdown(self): # Note(rob): shutdown function cannot be a bound method, # else the gc cannot collect the object. -def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): +def shutdown(proc: multiprocessing.Process, input_address: str): # Shutdown the process. if proc.is_alive(): proc.terminate() @@ -135,11 +132,12 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): kill_process_tree(proc.pid) # Remove zmq ipc socket files. - ipc_sockets = [output_path, input_path] + ipc_sockets = (input_address, ) for ipc_socket in ipc_sockets: - socket_file = ipc_socket.replace("ipc://", "") - if os and os.path.exists(socket_file): - os.remove(socket_file) + if ipc_socket.startswith("ipc://"): + socket_file = ipc_socket.replace("ipc://", "") + if os and os.path.exists(socket_file): + os.remove(socket_file) def bind_kv_cache( From e86938050a9f22b4c50446cb7f26a0ad7ac3d781 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 3 Apr 2025 11:19:27 -0700 Subject: [PATCH 02/22] Headless mode Signed-off-by: Nick Hill --- vllm/config.py | 3 +- vllm/engine/arg_utils.py | 33 ++++++++ vllm/entrypoints/cli/serve.py | 62 +++++++++++++- vllm/v1/engine/core.py | 52 ++++++++---- vllm/v1/engine/core_client.py | 155 +++++++++++++++------------------- vllm/v1/utils.py | 101 ++++++++++++++++------ 6 files changed, 270 insertions(+), 136 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 2d9854f865b..57ab0ef0596 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1436,6 +1436,7 @@ class ParallelConfig: data_parallel_rank_local: Optional[int] = None # IP of the data parallel master. data_parallel_master_ip: str = "127.0.0.1" + data_parallel_rpc_port: int = 29550 # Port for data parallel messaging. data_parallel_master_port: int = 29500 # Port of the data parallel master. enable_expert_parallel: bool = False # Use EP instead of TP for MoE layers. 
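This hunk leaves the head node with two distinct ports: data_parallel_master_port for the torch stateless process group and the new data_parallel_rpc_port for ZMQ messaging between the front-end and the engines. The transport choice that core_client.py builds on top of these can be condensed as follows (an illustrative sketch with made-up helper names, not the vLLM source):

def pick_addresses(dp_size: int, local_engines: int, head_ip: str,
                   rpc_port: int, output_port: int) -> tuple[str, str]:
    """Return (input_address, output_address) for engine messaging."""
    if local_engines == dp_size:
        # Everything is co-located: use unix-socket IPC transport.
        return "ipc:///tmp/vllm_in.sock", "ipc:///tmp/vllm_out.sock"
    # Remote engines must reach the head node over TCP. The input port
    # is fixed by configuration (headless nodes need it up front); the
    # output port may be ephemeral, since it travels in the init message.
    return (f"tcp://{head_ip}:{rpc_port}",
            f"tcp://{head_ip}:{output_port}")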
@@ -1538,7 +1539,7 @@ def __post_init__(self) -> None: self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size - if not (0 < self.data_parallel_size_local <= self.data_parallel_size): + if self.data_parallel_size_local > self.data_parallel_size: raise ValueError( "data_parallel_size_local must be <= data_parallel_size") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d6965a08ef0..b7cacb177dc 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -117,6 +117,9 @@ class EngineArgs: tensor_parallel_size: int = 1 data_parallel_size: int = 1 data_parallel_size_local: Optional[int] = None + data_parallel_start_rank: int = 0 + data_parallel_address: Optional[str] = None + data_parallel_rpc_port: Optional[int] = None enable_expert_parallel: bool = False max_parallel_loading_workers: Optional[int] = None block_size: Optional[int] = None @@ -435,6 +438,29 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'MoE layers will be sharded according to the ' 'product of the tensor-parallel-size and ' 'data-parallel-size.') + parser.add_argument('--data-parallel-size-local', + '-dpl', + type=int, + default=EngineArgs.data_parallel_size_local, + help='Number of data parallel replicas to run on ' + 'this node.') + parser.add_argument('--data-parallel-start-rank', + '-dpr', + type=int, + default=EngineArgs.data_parallel_start_rank, + help='Starting data parallel rank for secondary ' + 'nodes.') + parser.add_argument('--data-parallel-address', + '-dpa', + type=str, + default=EngineArgs.data_parallel_address, + help='Address of data parallel cluster head-node.') + parser.add_argument('--data-parallel-rpc-port', + '-dpp', + type=int, + default=EngineArgs.data_parallel_rpc_port, + help='Port for data parallel RPC communication.') + parser.add_argument( '--enable-expert-parallel', action='store_true', @@ -1192,11 +1218,18 @@ def create_engine_config( self.data_parallel_size_local is None) else self.data_parallel_size_local + # This port is only used when there are remote data parallel engines, + # otherwise the local IPC transport is used. 
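# (ParallelConfig.data_parallel_rpc_port defaults to 29550, added above.)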
+ data_parallel_rpc_port = self.data_parallel_rpc_port if ( + self.data_parallel_rpc_port + is not None) else (ParallelConfig.data_parallel_rpc_port) + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, data_parallel_size=self.data_parallel_size, data_parallel_size_local=data_parallel_size_local, + data_parallel_rpc_port=data_parallel_rpc_port, enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index e89ac4e2199..801dd6db3d7 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -4,11 +4,20 @@ import uvloop +import vllm.envs as envs +from vllm import AsyncEngineArgs from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.openai.api_server import run_server from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser +from vllm.v1.engine.core import EngineCoreProc +from vllm.v1.engine.core_client import CoreEngineProcManager +from vllm.v1.executor.abstract import Executor + +logger = init_logger(__name__) class ServeSubcommand(CLISubcommand): @@ -24,7 +33,10 @@ def cmd(args: argparse.Namespace) -> None: if hasattr(args, 'model_tag') and args.model_tag is not None: args.model = args.model_tag - uvloop.run(run_server(args)) + if args.headless: + run_headless(args) + else: + uvloop.run(run_server(args)) def validate(self, args: argparse.Namespace) -> None: validate_parsed_serve_args(args) @@ -41,6 +53,12 @@ def subparser_init( nargs='?', help="The model tag to serve " "(optional if specified in config)") + serve_parser.add_argument( + "--headless", + action='store_true', + default=False, + help="Run in headless mode. See multi-node data parallel " + "documentation for more details.") serve_parser.add_argument( "--config", type=str, @@ -56,3 +74,45 @@ def subparser_init( def cmd_init() -> list[CLISubcommand]: return [ServeSubcommand()] + + +def run_headless(args: argparse.Namespace): + + # Create the EngineConfig. + engine_args = AsyncEngineArgs.from_cli_args(args) + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = engine_args.create_engine_config(usage_context=usage_context) + + if not envs.VLLM_USE_V1: + raise RuntimeError("Headless mode is only supported for V1") + + parallel_config = vllm_config.parallel_config + local_engine_count = parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + port = engine_args.data_parallel_rpc_port # add to config too + input_address = f"tcp://{host}:{port}" + + if local_engine_count <= 0: + raise RuntimeError("data_parallel_size_local must be > 0 in " + "headless mode") + + logger.info( + "Launching %d data parallel engine(s) in headless mode, " + "with head node address %s.", local_engine_count, input_address) + + # Create the engines. 
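# No API server runs on this node: the procs connect their DEALER
# sockets back to the head node's ROUTER at input_address, and this
# process just blocks in join_first() below until one of them exits.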
+ engine_manager = CoreEngineProcManager( + target_fn=EngineCoreProc.run_engine_core, + local_engine_count=local_engine_count, + start_index=engine_args.data_parallel_start_rank, + vllm_config=vllm_config, + on_head_node=False, + input_address=input_address, + executor_class=Executor.get_class(vllm_config), + log_stats=not engine_args.disable_log_stats, + ) + + try: + engine_manager.join_first() + finally: + engine_manager.close() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b7fc3eb6ff7..1399604692e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -316,13 +316,6 @@ def __init__( log_stats: bool, engine_index: int = 0, ): - super().__init__(vllm_config, executor_class, log_stats) - - self.step_fn = (self.step if self.batch_queue is None else - self.step_with_batch_queue) - - self.global_unfinished_reqs = False - # Create input socket. input_ctx = zmq.Context() # type: ignore[attr-defined] identity = engine_index.to_bytes(length=2, byteorder="little") @@ -336,6 +329,24 @@ def __init__( output_address = self.startup_handshake(input_socket, on_head_node, vllm_config.parallel_config) + # Set up data parallel environment. + self._init_data_parallel(vllm_config) + + # Initialize engine core and model. + super().__init__(vllm_config, executor_class, log_stats) + + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) + + self.global_unfinished_reqs = False + + # Send ready message. + input_socket.send( + msgspec.msgpack.encode({ + "status": "READY", + "local": on_head_node + })) + # Background Threads and Queues for IO. These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, # and to overlap some serialization/deserialization with the @@ -358,8 +369,8 @@ def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, # Send registration message. input_socket.send( msgspec.msgpack.encode({ + "status": "HELLO", "local": on_head_node, - "status": "READY" })) # Receive initialization message. @@ -430,6 +441,9 @@ def signal_handler(signum, frame): if engine_core is not None: engine_core.shutdown() + def _init_data_parallel(self, vllm_config: VllmConfig): + pass + def run_busy_loop(self): """Core busy loop of the EngineCore.""" @@ -571,8 +585,20 @@ def __init__( _add_prefix(sys.stdout, process_name, pid) _add_prefix(sys.stderr, process_name, pid) - dp_size = vllm_config.parallel_config.data_parallel_size + # Counts forward-passes of the model so that we can synchronize + # finished with DP peers every N steps. + self.counter = 0 + + # Initialize the engine. dp_rank = vllm_config.parallel_config.data_parallel_rank + super().__init__(vllm_config, on_head_node, input_address, + executor_class, log_stats, dp_rank) + + def _init_data_parallel(self, vllm_config: VllmConfig): + + # Configure GPUs and stateless process group for data parallel. + dp_rank = vllm_config.parallel_config.data_parallel_rank + dp_size = vllm_config.parallel_config.data_parallel_size local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local assert dp_size > 1 @@ -589,14 +615,6 @@ def __init__( self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() - # Initialize the engine after setting up environment. - super().__init__(vllm_config, on_head_node, input_address, - executor_class, log_stats, dp_rank) - - # Counts forward-passes of the model so that we can synchronize - # finished with DP peers every N steps. 
- self.counter = 0 - def shutdown(self): super().shutdown() if dp_group := getattr(self, "dp_group", None): diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index ace6470faae..c41ef85a175 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -29,7 +29,7 @@ from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder -from vllm.v1.utils import BackgroundProcHandle +from vllm.v1.utils import CoreEngineProcManager logger = init_logger(__name__) @@ -257,52 +257,6 @@ def collective_rpc(self, return self.engine_core.collective_rpc(method, timeout, args, kwargs) -class CoreEngineProcManager: - """One per data parallel rank.""" - - def __init__( - self, - local_engine_count: int, - start_index: int, - vllm_config: VllmConfig, - on_head_node: bool, - input_address: str, - executor_class: type[Executor], - log_stats: bool, - ): - self.proc_handles = [] - try: - for local_index in range(local_engine_count): - index = local_index + start_index - # Start EngineCore in background process. - self.proc_handles.append( - BackgroundProcHandle( - input_address=input_address, - process_name=f"EngineCore_{index}", - target_fn=EngineCoreProc.run_engine_core, - process_kwargs={ - "vllm_config": vllm_config, - "on_head_node": on_head_node, - "dp_rank": index, - "local_dp_rank": local_index, - "executor_class": executor_class, - "log_stats": log_stats, - })) - finally: - if len(self.proc_handles) != local_engine_count: - self.close() - - def close(self): - for proc_handle in self.proc_handles: - proc_handle.shutdown() - - def finished_procs(self) -> dict[int, int]: - return { - handle.proc.name: handle.proc.exitcode - for handle in self.proc_handles if handle.proc.exitcode is not None - } - - class CoreEngine: """One per data parallel rank.""" @@ -398,18 +352,17 @@ def sigusr1_handler(signum, frame): self.resources = BackgroundResources(ctx=sync_ctx) self._finalizer = weakref.finalize(self, self.resources) - # TODO + # TODO move address setup to separate method parallel_config = vllm_config.parallel_config dp_size = parallel_config.data_parallel_size local_engine_count = parallel_config.data_parallel_size_local - # TODO somewhere validate local count <= dp_size if local_engine_count == dp_size: input_address = get_open_zmq_ipc_path() output_address = get_open_zmq_ipc_path() else: host = parallel_config.data_parallel_master_ip - input_port = 13345 # todo from arg/config + input_port = parallel_config.data_parallel_rpc_port output_port = get_open_port() input_address = f"tcp://{host}:{input_port}" output_address = f"tcp://{host}:{output_port}" @@ -421,10 +374,10 @@ def sigusr1_handler(signum, frame): self.resources.output_socket = make_zmq_socket(self.ctx, output_address, zmq.constants.PULL) - # Start local engines. if local_engine_count: self.resources.local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, vllm_config=vllm_config, executor_class=executor_class, log_stats=log_stats, @@ -446,56 +399,80 @@ def _wait_for_engine_startup(self, output_address: str, # Get a sync handle to the socket which can be sync or async. sync_input_socket = zmq.Socket.shadow(self.input_socket) + # TODO offline case compatibility + # Wait for engine core process(es) to send ready messages. 
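# Each engine now reports twice: HELLO once its socket connects, READY
# once the model is initialized. Both carry a "local" flag so the local
# and remote pools can be tracked separately.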
local_engine_count = parallel_config.data_parallel_size_local - # TODO offline case compatibility - local_indices = set(range(local_engine_count)) - remote_indices = set( - range(len(self.core_engines) - local_engine_count)) - while local_indices or remote_indices: + remote_engine_count = len(self.core_engines) - local_engine_count + + # TODO simplify the startup tracking logic below! + pending_hello_local = set(range(local_engine_count)) + pending_hello_remote = set( + range(local_engine_count, len(self.core_engines))) + pending_ready_local = set(pending_hello_local) + pending_ready_remote = set(pending_hello_remote) + while pending_ready_local or pending_ready_remote: while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS): - local_count = len(local_indices) - if remote_indices: - remote_count = len(remote_indices) + local_conn = local_engine_count - len(pending_hello_local) + local_ready = local_engine_count - len(pending_ready_local) + if local_ready != local_engine_count: logger.info( - "Waiting for %d local and %d remote core engine " - "proc(s) to start: %s, %s", local_count, remote_count, - local_indices, remote_indices) - else: - logger.info( - "Waiting for %d local core engine proc(s) " - "to start: %s", local_count, local_indices) + "Waiting for local core engine procs: " + "%d/%d connected, %d/%d ready.", local_conn, + local_engine_count, local_ready, local_engine_count) + if remote_engine_count: + remote_conn = remote_engine_count - len( + pending_hello_remote) + remote_ready = remote_engine_count - len( + pending_ready_remote) + if remote_ready != remote_engine_count: + logger.info( + "Waiting for remote core engine procs: " + "%d/%d connected, %d/%d ready.", remote_conn, + remote_engine_count, remote_ready, + remote_engine_count) + + # Receive HELLO and READY messages from the input socket. eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() - ready_msg = msgspec.msgpack.decode(ready_msg_bytes) - local, status = ready_msg["local"], ready_msg["status"] eng_index = int.from_bytes(eng_identity, byteorder="little") - if status != "READY": + msg = msgspec.msgpack.decode(ready_msg_bytes) + status, local = msg["status"], msg["local"] + hello_set = pending_hello_local if local else pending_hello_remote + ready_set = pending_ready_local if local else pending_ready_remote + if status == "HELLO": + index_set = hello_set + elif status == "READY": + index_set = ready_set + else: raise RuntimeError(f"{'Local' if local else 'Remote'} engine " f"{eng_index} failed: {status}") - - index_set = local_indices if local else remote_indices if eng_index not in index_set: raise RuntimeError( - f"Unexpected or duplicate " + f"Unexpected or duplicate {status} " + f"{'local' if local else 'remote'} engine: {eng_index}") + if status == "READY" and eng_index in hello_set: + raise RuntimeError( + f"Unexpected READY before HELLO for " f"{'local' if local else 'remote'} engine: {eng_index}") - # Send init message with DP config info. 
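# (The init reply moves under the HELLO branch below, so DP config is
# sent once per engine on first contact rather than on READY.)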
- init_message = self.encoder.encode({ - "output_socket_address": output_address, - "parallel_config": { - "data_parallel_master_ip": - parallel_config.data_parallel_master_ip, - "data_parallel_master_port": - parallel_config.data_parallel_master_port, - "data_parallel_size": parallel_config.data_parallel_size, - }, - }) - - sync_input_socket.send_multipart((eng_identity, init_message), - copy=False) - - logger.debug("%s core engine process %d ready.", - "Local" if local else "Remote", eng_index) + if status == "HELLO": + # Send init message with DP config info. + init_message = self.encoder.encode({ + "output_socket_address": output_address, + "parallel_config": { + "data_parallel_master_ip": + parallel_config.data_parallel_master_ip, + "data_parallel_master_port": + parallel_config.data_parallel_master_port, + "data_parallel_size": + parallel_config.data_parallel_size, + }, + }) + sync_input_socket.send_multipart((eng_identity, init_message), + copy=False) + + logger.debug("%s from %s core engine process %s.", status, + "local" if local else "remote", eng_index) index_set.discard(eng_index) # Double check that the process are running. diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 470e2a572ed..adfdb86a705 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -2,17 +2,21 @@ import multiprocessing import os +import time import weakref from collections import defaultdict from collections.abc import Sequence -from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, - Union, overload) +from multiprocessing import connection +from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union, + overload) import torch +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index from vllm.utils import get_mp_context, kill_process_tree +from vllm.v1.executor.abstract import Executor if TYPE_CHECKING: from vllm.attention.layer import Attention @@ -90,7 +94,7 @@ def __repr__(self): return f"ConstantList({self._x})" -class BackgroundProcHandle: +class CoreEngineProcManager: """ Utility class to handle creation, readiness, and shutdown of background processes used by the AsyncLLM and LLMEngine. @@ -98,46 +102,87 @@ class BackgroundProcHandle: def __init__( self, - input_address: str, - process_name: str, target_fn: Callable, - process_kwargs: dict[Any, Any], + local_engine_count: int, + start_index: int, + vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, + executor_class: type[Executor], + log_stats: bool, ): context = get_mp_context() - - assert "input_address" not in process_kwargs - process_kwargs["input_address"] = input_address - - # Run busy loop in background process. - self.proc = context.Process(target=target_fn, - kwargs=process_kwargs, - name=process_name) - self._finalizer = weakref.finalize(self, shutdown, self.proc, + common_kwargs = { + "vllm_config": vllm_config, + "on_head_node": on_head_node, + "input_address": input_address, + "executor_class": executor_class, + "log_stats": log_stats, + } + + self.processes = [] + for local_index in range(local_engine_count): + index = local_index + start_index + # Start EngineCore in background process. 
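# Children are named EngineCore_<global dp rank>; that name is what
# finished_procs() reports and what the per-proc log prefix shows.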
+ self.processes.append( + context.Process(target=target_fn, + name=f"EngineCore_{index}", + kwargs=common_kwargs | { + "dp_rank": index, + "local_dp_rank": local_index, + })) + + self._finalizer = weakref.finalize(self, shutdown, self.processes, input_address) - self.proc.start() - - def shutdown(self): + try: + for proc in self.processes: + proc.start() + finally: + # Kill other procs if not all are running. + if self.finished_procs(): + self.close() + + def close(self): + """Shutdown all procs.""" self._finalizer() + def join_first(self): + """Wait for any process to exit.""" + connection.wait(proc.sentinel for proc in self.processes) + + def finished_procs(self) -> dict[int, int]: + """Returns dict of proc name -> exit code for any finished procs.""" + return { + proc.name: proc.exitcode + for proc in self.processes if proc.exitcode is not None + } + # Note(rob): shutdown function cannot be a bound method, # else the gc cannot collect the object. -def shutdown(proc: multiprocessing.Process, input_address: str): +def shutdown(procs: list[multiprocessing.Process], input_address: str): # Shutdown the process. - if proc.is_alive(): - proc.terminate() - proc.join(5) + for proc in procs: + if proc.is_alive(): + proc.terminate() + + # Allow 5 seconds for remaining procs to terminate. + deadline = time.monotonic() + 5 + for proc in procs: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + proc.join(remaining) + for proc in procs: if proc.is_alive(): kill_process_tree(proc.pid) # Remove zmq ipc socket files. - ipc_sockets = (input_address, ) - for ipc_socket in ipc_sockets: - if ipc_socket.startswith("ipc://"): - socket_file = ipc_socket.replace("ipc://", "") - if os and os.path.exists(socket_file): - os.remove(socket_file) + if input_address.startswith("ipc://"): + socket_file = input_address[len("ipc://"):] + if os and os.path.exists(socket_file): + os.remove(socket_file) def bind_kv_cache( From 1ca3d1598f17d17d40d25a5ba3e44072798628d4 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 3 Apr 2025 19:23:39 -0700 Subject: [PATCH 03/22] Wire data_parallel_address arg Signed-off-by: Nick Hill --- vllm/engine/arg_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b7cacb177dc..729c1be1321 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1218,17 +1218,24 @@ def create_engine_config( self.data_parallel_size_local is None) else self.data_parallel_size_local + # DP address, used in multi-node case for torch distributed group + # and ZMQ sockets. + data_parallel_address = self.data_parallel_address if ( + self.data_parallel_address + is not None) else ParallelConfig.data_parallel_master_ip + # This port is only used when there are remote data parallel engines, # otherwise the local IPC transport is used. 
data_parallel_rpc_port = self.data_parallel_rpc_port if ( self.data_parallel_rpc_port - is not None) else (ParallelConfig.data_parallel_rpc_port) + is not None) else ParallelConfig.data_parallel_rpc_port parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, data_parallel_size=self.data_parallel_size, data_parallel_size_local=data_parallel_size_local, + data_parallel_master_ip=data_parallel_address, data_parallel_rpc_port=data_parallel_rpc_port, enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, From a5511836d0cee8007d326be57cce7b49340ce2c0 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 3 Apr 2025 21:33:05 -0700 Subject: [PATCH 04/22] Some code cleanup Signed-off-by: Nick Hill --- vllm/v1/engine/core_client.py | 128 +++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 58 deletions(-) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index c41ef85a175..c5d9b16f8fe 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -11,6 +11,7 @@ from collections.abc import Awaitable from concurrent.futures import Future from dataclasses import dataclass +from enum import Enum, auto from threading import Thread from typing import Any, Callable, Optional, TypeVar, Union @@ -257,11 +258,20 @@ def collective_rpc(self, return self.engine_core.collective_rpc(method, timeout, args, kwargs) +class CoreEngineState(Enum): + NEW = auto() + CONNECTED = auto() + READY = auto() + + class CoreEngine: """One per data parallel rank.""" - def __init__(self, index: int = 0): + def __init__(self, index: int = 0, local: bool = True): + self.local = local self.identity = index.to_bytes(length=2, byteorder="little") + + self.state = CoreEngineState.NEW self.num_reqs_in_flight = 0 @@ -352,20 +362,12 @@ def sigusr1_handler(signum, frame): self.resources = BackgroundResources(ctx=sync_ctx) self._finalizer = weakref.finalize(self, self.resources) - # TODO move address setup to separate method parallel_config = vllm_config.parallel_config dp_size = parallel_config.data_parallel_size local_engine_count = parallel_config.data_parallel_size_local - if local_engine_count == dp_size: - input_address = get_open_zmq_ipc_path() - output_address = get_open_zmq_ipc_path() - else: - host = parallel_config.data_parallel_master_ip - input_port = parallel_config.data_parallel_rpc_port - output_port = get_open_port() - input_address = f"tcp://{host}:{input_port}" - output_address = f"tcp://{host}:{output_port}" + input_address, output_address = self._get_zmq_addresses( + parallel_config) # Create input and output sockets. self.input_socket = self.resources.input_socket = make_zmq_socket( @@ -386,7 +388,10 @@ def sigusr1_handler(signum, frame): local_engine_count=local_engine_count, start_index=0) - self.core_engines = [CoreEngine(i) for i in range(dp_size)] + self.core_engines = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(dp_size) + ] self.core_engine = self.core_engines[0] # Wait for engine core process(es) to start. 
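The hunk that follows replaces the pending-set bookkeeping with a small per-engine state machine: NEW becomes CONNECTED on HELLO, CONNECTED becomes READY on READY, and anything else is fatal. Stripped of the socket plumbing, the transition rule amounts to this (a sketch, not the vLLM source):

from enum import Enum, auto

class State(Enum):
    NEW = auto()
    CONNECTED = auto()
    READY = auto()

def advance(state: State, status: str) -> State:
    # HELLO is only legal from NEW, READY only from CONNECTED; anything
    # else is a duplicate or out-of-order message from that engine.
    if status == "HELLO" and state is State.NEW:
        return State.CONNECTED
    if status == "READY" and state is State.CONNECTED:
        return State.READY
    raise RuntimeError(f"unexpected {status} in state {state.name}")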
@@ -394,6 +399,24 @@ def sigusr1_handler(signum, frame): self.utility_results: dict[int, AnyFuture] = {} + @staticmethod + def _get_zmq_addresses(parallel_config: ParallelConfig) -> tuple[str, str]: + """Returns (input_address, output_address).""" + dp_size = parallel_config.data_parallel_size + local_engine_count = parallel_config.data_parallel_size_local + + if local_engine_count == dp_size: + input_address = get_open_zmq_ipc_path() + output_address = get_open_zmq_ipc_path() + else: + host = parallel_config.data_parallel_master_ip + input_port = parallel_config.data_parallel_rpc_port + output_port = get_open_port() + input_address = f"tcp://{host}:{input_port}" + output_address = f"tcp://{host}:{output_port}" + + return input_address, output_address + def _wait_for_engine_startup(self, output_address: str, parallel_config: ParallelConfig): # Get a sync handle to the socket which can be sync or async. @@ -402,60 +425,39 @@ def _wait_for_engine_startup(self, output_address: str, # TODO offline case compatibility # Wait for engine core process(es) to send ready messages. - local_engine_count = parallel_config.data_parallel_size_local - remote_engine_count = len(self.core_engines) - local_engine_count - - # TODO simplify the startup tracking logic below! - pending_hello_local = set(range(local_engine_count)) - pending_hello_remote = set( - range(local_engine_count, len(self.core_engines))) - pending_ready_local = set(pending_hello_local) - pending_ready_remote = set(pending_hello_remote) - while pending_ready_local or pending_ready_remote: + local_count = parallel_config.data_parallel_size_local + remote_count = len(self.core_engines) - local_count + # [local, remote] counts + conn_pending, start_pending = [local_count, remote_count], [0, 0] + + while any(conn_pending) or any(start_pending): while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS): - local_conn = local_engine_count - len(pending_hello_local) - local_ready = local_engine_count - len(pending_ready_local) - if local_ready != local_engine_count: + if any(conn_pending): logger.info( - "Waiting for local core engine procs: " - "%d/%d connected, %d/%d ready.", local_conn, - local_engine_count, local_ready, local_engine_count) - if remote_engine_count: - remote_conn = remote_engine_count - len( - pending_hello_remote) - remote_ready = remote_engine_count - len( - pending_ready_remote) - if remote_ready != remote_engine_count: - logger.info( - "Waiting for remote core engine procs: " - "%d/%d connected, %d/%d ready.", remote_conn, - remote_engine_count, remote_ready, - remote_engine_count) + "Waiting for %d local, %d remote core engine proc(s) " + "to connect.", *conn_pending) + if any(start_pending): + logger.info( + "Waiting for %d local, %d remote core engine proc(s) " + "to start.", *start_pending) # Receive HELLO and READY messages from the input socket. 
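# ROUTER sockets prepend the sender's identity frame, so the 2-byte
# rank identifies which CoreEngine record each message belongs to.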
eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() eng_index = int.from_bytes(eng_identity, byteorder="little") + if eng_index > len(self.core_engines): + raise RuntimeError( + f"Message from engine rank larger than " + f"configured data parallel size: {eng_index}") + engine = self.core_engines[eng_index] msg = msgspec.msgpack.decode(ready_msg_bytes) status, local = msg["status"], msg["local"] - hello_set = pending_hello_local if local else pending_hello_remote - ready_set = pending_ready_local if local else pending_ready_remote - if status == "HELLO": - index_set = hello_set - elif status == "READY": - index_set = ready_set - else: - raise RuntimeError(f"{'Local' if local else 'Remote'} engine " - f"{eng_index} failed: {status}") - if eng_index not in index_set: - raise RuntimeError( - f"Unexpected or duplicate {status} " - f"{'local' if local else 'remote'} engine: {eng_index}") - if status == "READY" and eng_index in hello_set: - raise RuntimeError( - f"Unexpected READY before HELLO for " - f"{'local' if local else 'remote'} engine: {eng_index}") + if local != engine.local: + raise RuntimeError(f"{status} message from " + f"{'local' if local else 'remote'} " + f" engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}") + if status == "HELLO" and engine.state == CoreEngineState.NEW: - if status == "HELLO": # Send init message with DP config info. init_message = self.encoder.encode({ "output_socket_address": output_address, @@ -470,10 +472,20 @@ def _wait_for_engine_startup(self, output_address: str, }) sync_input_socket.send_multipart((eng_identity, init_message), copy=False) + conn_pending[0 if local else 1] -= 1 + start_pending[0 if local else 1] += 1 + engine.state = CoreEngineState.CONNECTED + elif status == "READY" and (engine.state + == CoreEngineState.CONNECTED): + start_pending[0 if local else 1] -= 1 + engine.state = CoreEngineState.READY + else: + raise RuntimeError(f"Unexpected {status} message for " + f"{'local' if local else 'remote'} engine " + f"{eng_index} in {engine.state} state.") logger.debug("%s from %s core engine process %s.", status, "local" if local else "remote", eng_index) - index_set.discard(eng_index) # Double check that the process are running. engine_manager = self.resources.local_engine_manager From a6621696742fec492842fb5c8168f069070c48b1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 4 Apr 2025 11:08:49 -0700 Subject: [PATCH 05/22] Fix offline DP compatibility Signed-off-by: Nick Hill --- vllm/config.py | 1 - vllm/entrypoints/cli/serve.py | 1 + vllm/v1/engine/core_client.py | 50 +++++++++++++++++++++++------------ vllm/v1/utils.py | 13 +++++---- 4 files changed, 42 insertions(+), 23 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 57ab0ef0596..9152d847a6f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1546,7 +1546,6 @@ def __post_init__(self) -> None: if self.data_parallel_size > 1: # Data parallel was specified in the engine args. self.data_parallel_master_port = get_open_port() - # TODO multi-node else: # Otherwise fall back to env vars (e.g. for offline SPMD case). 
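# (Offline SPMD launchers set these env vars per process; see
# VLLM_DP_SIZE read just below.)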
self.data_parallel_size = envs.VLLM_DP_SIZE diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 801dd6db3d7..28362613dce 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -105,6 +105,7 @@ def run_headless(args: argparse.Namespace): target_fn=EngineCoreProc.run_engine_core, local_engine_count=local_engine_count, start_index=engine_args.data_parallel_start_rank, + local_start_index=0, vllm_config=vllm_config, on_head_node=False, input_address=input_address, diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index c5d9b16f8fe..144ea5bc9d6 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -363,11 +363,28 @@ def sigusr1_handler(signum, frame): self._finalizer = weakref.finalize(self, self.resources) parallel_config = vllm_config.parallel_config - dp_size = parallel_config.data_parallel_size local_engine_count = parallel_config.data_parallel_size_local + start_index = parallel_config.data_parallel_rank + local_start_index = parallel_config.data_parallel_rank_local + + # SPMD mode is where there is an LLM instance per DP rank and one + # core engine per LLM, see examples/offline_inference/data_parallel.py. + spmd_mode = local_start_index is not None + if spmd_mode: + assert local_engine_count == 1 + self.core_engines = [ + CoreEngine(index=local_start_index, local=True) + ] + else: + assert start_index == 0 + local_start_index = 0 + self.core_engines = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(parallel_config.data_parallel_size) + ] input_address, output_address = self._get_zmq_addresses( - parallel_config) + parallel_config, spmd_mode) # Create input and output sockets. self.input_socket = self.resources.input_socket = make_zmq_socket( @@ -378,6 +395,7 @@ def sigusr1_handler(signum, frame): zmq.constants.PULL) # Start local engines. if local_engine_count: + # In server mode, start_index and local_start_index will both be 0. self.resources.local_engine_manager = CoreEngineProcManager( EngineCoreProc.run_engine_core, vllm_config=vllm_config, @@ -386,12 +404,9 @@ def sigusr1_handler(signum, frame): input_address=input_address, on_head_node=True, local_engine_count=local_engine_count, - start_index=0) + start_index=start_index, + local_start_index=local_start_index) - self.core_engines = [ - CoreEngine(index=i, local=(i < local_engine_count)) - for i in range(dp_size) - ] self.core_engine = self.core_engines[0] # Wait for engine core process(es) to start. @@ -400,12 +415,13 @@ def sigusr1_handler(signum, frame): self.utility_results: dict[int, AnyFuture] = {} @staticmethod - def _get_zmq_addresses(parallel_config: ParallelConfig) -> tuple[str, str]: + def _get_zmq_addresses(parallel_config: ParallelConfig, + spmd_mode: bool) -> tuple[str, str]: """Returns (input_address, output_address).""" dp_size = parallel_config.data_parallel_size local_engine_count = parallel_config.data_parallel_size_local - if local_engine_count == dp_size: + if local_engine_count == dp_size or spmd_mode: input_address = get_open_zmq_ipc_path() output_address = get_open_zmq_ipc_path() else: @@ -422,8 +438,6 @@ def _wait_for_engine_startup(self, output_address: str, # Get a sync handle to the socket which can be sync or async. sync_input_socket = zmq.Socket.shadow(self.input_socket) - # TODO offline case compatibility - # Wait for engine core process(es) to send ready messages. 
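# In SPMD mode there is exactly one local engine and no remotes, so
# this loop degenerates to a single HELLO/READY exchange.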
local_count = parallel_config.data_parallel_size_local remote_count = len(self.core_engines) - local_count @@ -444,18 +458,20 @@ def _wait_for_engine_startup(self, output_address: str, # Receive HELLO and READY messages from the input socket. eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() eng_index = int.from_bytes(eng_identity, byteorder="little") - if eng_index > len(self.core_engines): - raise RuntimeError( - f"Message from engine rank larger than " - f"configured data parallel size: {eng_index}") - engine = self.core_engines[eng_index] + engine = next( + (e for e in self.core_engines if e.identity == eng_identity), + None) + if engine is None: + raise RuntimeError(f"Message from engine with unexpected data " + f"parallel rank: {eng_index}") msg = msgspec.msgpack.decode(ready_msg_bytes) status, local = msg["status"], msg["local"] if local != engine.local: raise RuntimeError(f"{status} message from " f"{'local' if local else 'remote'} " - f" engine {eng_index}, expected it to be " + f"engine {eng_index}, expected it to be " f"{'local' if engine.local else 'remote'}") + if status == "HELLO" and engine.state == CoreEngineState.NEW: # Send init message with DP config info. diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index adfdb86a705..e6f947af4d8 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -105,6 +105,7 @@ def __init__( target_fn: Callable, local_engine_count: int, start_index: int, + local_start_index: int, vllm_config: VllmConfig, on_head_node: bool, input_address: str, @@ -121,14 +122,15 @@ def __init__( } self.processes = [] - for local_index in range(local_engine_count): - index = local_index + start_index + for index in range(local_engine_count): + local_index = local_start_index + index + global_index = start_index + index # Start EngineCore in background process. 
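            # (dp_rank is the engine's global rank across all nodes;
            # local_dp_rank indexes only the engines on this node.)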
self.processes.append( context.Process(target=target_fn, - name=f"EngineCore_{index}", + name=f"EngineCore_{global_index}", kwargs=common_kwargs | { - "dp_rank": index, + "dp_rank": global_index, "local_dp_rank": local_index, })) @@ -172,7 +174,8 @@ def shutdown(procs: list[multiprocessing.Process], input_address: str): remaining = deadline - time.monotonic() if remaining <= 0: break - proc.join(remaining) + if proc.is_alive(): + proc.join(remaining) for proc in procs: if proc.is_alive(): From 8126f726c1d5f6d2308e23adf58cebbf0fc11399 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 7 Apr 2025 15:42:14 -0700 Subject: [PATCH 06/22] Address some review comments Signed-off-by: Nick Hill --- vllm/engine/arg_utils.py | 7 --- vllm/entrypoints/cli/serve.py | 9 +++- vllm/v1/engine/core.py | 84 +++++++++++++++++++---------------- 3 files changed, 54 insertions(+), 46 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 554638924ed..7b6cbae8b15 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -118,7 +118,6 @@ class EngineArgs: tensor_parallel_size: int = 1 data_parallel_size: int = 1 data_parallel_size_local: Optional[int] = None - data_parallel_start_rank: int = 0 data_parallel_address: Optional[str] = None data_parallel_rpc_port: Optional[int] = None enable_expert_parallel: bool = False @@ -450,12 +449,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.data_parallel_size_local, help='Number of data parallel replicas to run on ' 'this node.') - parser.add_argument('--data-parallel-start-rank', - '-dpr', - type=int, - default=EngineArgs.data_parallel_start_rank, - help='Starting data parallel rank for secondary ' - 'nodes.') parser.add_argument('--data-parallel-address', '-dpa', type=str, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 0c2af151488..b9f64026d75 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -60,6 +60,13 @@ def subparser_init( default=False, help="Run in headless mode. See multi-node data parallel " "documentation for more details.") + serve_parser.add_argument( + '--data-parallel-start-rank', + '-dpr', + type=int, + default=0, + help='Starting data parallel rank for secondary ' + 'nodes.') serve_parser.add_argument( "--config", type=str, @@ -105,7 +112,7 @@ def run_headless(args: argparse.Namespace): engine_manager = CoreEngineProcManager( target_fn=EngineCoreProc.run_engine_core, local_engine_count=local_engine_count, - start_index=engine_args.data_parallel_start_rank, + start_index=args.data_parallel_start_rank, local_start_index=0, vllm_config=vllm_config, on_head_node=False, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 1399604692e..d0668fc8df0 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -43,6 +43,7 @@ logger = init_logger(__name__) POLLING_TIMEOUT_S = 2.5 +HANDSHAKE_TIMEOUT_MINS = 5 _R = TypeVar('_R') # Return type for collective_rpc @@ -324,43 +325,47 @@ def __init__( zmq.DEALER, identity=identity, bind=False) - - # Register engine with front-end. - output_address = self.startup_handshake(input_socket, on_head_node, - vllm_config.parallel_config) - - # Set up data parallel environment. - self._init_data_parallel(vllm_config) - - # Initialize engine core and model. 
- super().__init__(vllm_config, executor_class, log_stats) - - self.step_fn = (self.step if self.batch_queue is None else - self.step_with_batch_queue) - - self.global_unfinished_reqs = False - - # Send ready message. - input_socket.send( - msgspec.msgpack.encode({ - "status": "READY", - "local": on_head_node - })) - - # Background Threads and Queues for IO. These enable us to - # overlap ZMQ socket IO with GPU since they release the GIL, - # and to overlap some serialization/deserialization with the - # model forward pass. - # Threads handle Socket <-> Queues and core_busy_loop uses Queue. - self.input_queue: queue.Queue[tuple[EngineCoreRequestType, - Any]] = queue.Queue() - self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() - threading.Thread(target=self.process_input_socket, - args=(input_socket, ), - daemon=True).start() - threading.Thread(target=self.process_output_socket, - args=(output_address, engine_index), - daemon=True).start() + try: + # Register engine with front-end. + output_address = self.startup_handshake( + input_socket, on_head_node, vllm_config.parallel_config) + + # Set up data parallel environment. + self._init_data_parallel(vllm_config) + + # Initialize engine core and model. + super().__init__(vllm_config, executor_class, log_stats) + + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) + + self.global_unfinished_reqs = False + + # Send ready message. + input_socket.send( + msgspec.msgpack.encode({ + "status": "READY", + "local": on_head_node + })) + + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. + self.input_queue: queue.Queue[tuple[EngineCoreRequestType, + Any]] = queue.Queue() + self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() + threading.Thread(target=self.process_input_socket, + args=(input_socket, ), + daemon=True).start() + input_socket = None + threading.Thread(target=self.process_output_socket, + args=(output_address, engine_index), + daemon=True).start() + finally: + if input_socket is not None: + input_socket.close(linger=0) @staticmethod def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, @@ -375,7 +380,10 @@ def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, # Receive initialization message. logger.info("Waiting for init message from front-end.") - input_socket.poll(timeout=5 * 60 * 1000) + if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000): + raise RuntimeError("Did not receive response from front-end " + f"process within {HANDSHAKE_TIMEOUT_MINS} " + f"minutes") init_bytes = input_socket.recv() init_message = msgspec.msgpack.decode(init_bytes) logger.debug("Received init message: %s", init_message) From 8fdc6f5c120051bd8894443e53f3970d9207aa78 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 7 Apr 2025 16:09:42 -0700 Subject: [PATCH 07/22] Address other minor review comments Signed-off-by: Nick Hill --- vllm/v1/engine/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d0668fc8df0..f52d010eefd 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -318,7 +318,7 @@ def __init__( engine_index: int = 0, ): # Create input socket. 
- input_ctx = zmq.Context() # type: ignore[attr-defined] + input_ctx = zmq.Context() identity = engine_index.to_bytes(length=2, byteorder="little") input_socket = make_zmq_socket(input_ctx, input_address, @@ -389,7 +389,7 @@ def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, logger.debug("Received init message: %s", init_message) output_socket_address = init_message["output_socket_address"] - #TBD maybe replace IP with configured head node address + #TBD(nick) maybe replace IP with configured head node address received_parallel_config = init_message["parallel_config"] for key, value in received_parallel_config.items(): From efa8ad864370f2d2c0bb281354272bbc46469842 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 17 Apr 2025 12:19:21 -0700 Subject: [PATCH 08/22] Fix merge error, address @russellb's ipv6 review comment Signed-off-by: Nick Hill --- vllm/utils.py | 4 ++++ vllm/v1/engine/core_client.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index c6e2afff72d..350dce8f02b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -604,6 +604,10 @@ def is_valid_ipv6_address(address: str) -> bool: def get_distributed_init_method(ip: str, port: int) -> str: + return get_tcp_uri(ip, port) + + +def get_tcp_uri(ip: str, port: int) -> str: # Brackets are not permitted in ipv4 addresses, # see https://github.com/python/cpython/issues/103848 return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index bc2d5f07ebb..ba3bfbe6606 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -19,7 +19,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.utils import (get_open_port, get_open_zmq_inproc_path, - get_open_zmq_ipc_path, make_zmq_socket) + get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.core import EngineCore, EngineCoreProc @@ -423,8 +423,8 @@ def _get_zmq_addresses(parallel_config: ParallelConfig, host = parallel_config.data_parallel_master_ip input_port = parallel_config.data_parallel_rpc_port output_port = get_open_port() - input_address = f"tcp://{host}:{input_port}" - output_address = f"tcp://{host}:{output_port}" + input_address = get_tcp_uri(host, input_port) + output_address = get_tcp_uri(host, output_port) return input_address, output_address @@ -496,7 +496,7 @@ def _wait_for_engine_startup(self, output_address: str, parallel_config.data_parallel_size, }, }) - sync_input_socket.send_multipart((eng_identity, init_message), + sync_input_socket.send_multipart((eng_identity, *init_message), copy=False) conn_pending[0 if local else 1] -= 1 start_pending[0 if local else 1] += 1 From 30ab14b38791f7eaa8d4eec2e07756a8459c8282 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 18 Apr 2025 08:58:33 -0700 Subject: [PATCH 09/22] Hande ipv6 URIs in all places Signed-off-by: Nick Hill --- vllm/distributed/utils.py | 3 ++- vllm/entrypoints/cli/serve.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 2cb57afd456..442a79bc716 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -21,6 +21,7 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.utils import get_tcp_uri logger = init_logger(__name__) @@ -283,7 +284,7 @@ def 
stateless_init_torch_distributed_process_group( always formed with process 1, 2, ..., 8, and the additional communication channel is formed with process 9 and 10. """ - init_method = f"tcp://{host}:{port}" + init_method = get_tcp_uri(host, port) backend = Backend(backend) # it is basically string timeout = _get_default_timeout(backend) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index b9f64026d75..4b3e134a485 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -12,7 +12,7 @@ validate_parsed_serve_args) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, get_tcp_uri from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.core_client import CoreEngineProcManager from vllm.v1.executor.abstract import Executor @@ -98,7 +98,7 @@ def run_headless(args: argparse.Namespace): local_engine_count = parallel_config.data_parallel_size_local host = parallel_config.data_parallel_master_ip port = engine_args.data_parallel_rpc_port # add to config too - input_address = f"tcp://{host}:{port}" + input_address = get_tcp_uri(host, port) if local_engine_count <= 0: raise RuntimeError("data_parallel_size_local must be > 0 in " From acc5af341fdbf91a7ae9b1d0b21526533fb52bf7 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 18 Apr 2025 19:10:31 -0700 Subject: [PATCH 10/22] Fix head node with no engines, don't require dp size on other nodes Signed-off-by: Nick Hill --- vllm/config.py | 14 ++++++++------ vllm/entrypoints/cli/serve.py | 13 +++++++++++-- vllm/v1/engine/core.py | 5 ++++- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index d50987a77a8..6d4c5f1a77c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1621,13 +1621,16 @@ class is dynamically inherited by the worker class. This is used to inject world_size: int = field(init=False) """world_size is TPxPP, it affects the number of workers we create.""" - world_size_across_dp: int = field(init=False) - """world_size_across_dp is TPxPPxDP, it is the size of the world - including data parallelism.""" rank: int = 0 """Global rank in distributed setup.""" + @property + def world_size_across_dp(self) -> int: + """world_size_across_dp is TPxPPxDP, it is the size of the world + including data parallelism.""" + return self.world_size * self.data_parallel_size + def get_next_dp_init_port(self) -> int: """ We might need to initialize process groups in multiple @@ -1680,6 +1683,7 @@ def compute_hash(self): factors: list[Any] = [] factors.append(self.pipeline_parallel_size) factors.append(self.tensor_parallel_size) + factors.append(self.data_parallel_size) return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: @@ -1690,7 +1694,7 @@ def __post_init__(self) -> None: raise ValueError( "data_parallel_size_local must be <= data_parallel_size") - if self.data_parallel_size > 1: + if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. 
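            # (data_parallel_size_local == 0 means a head node that runs no
            # engines itself but still owns the DP master port.)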
self.data_parallel_master_port = get_open_port() else: @@ -1701,8 +1705,6 @@ def __post_init__(self) -> None: self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT - self.world_size_across_dp = self.world_size * self.data_parallel_size - if self.distributed_executor_backend == "external_launcher": import os os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 4b3e134a485..04be7c03399 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import signal import uvloop @@ -65,8 +66,7 @@ def subparser_init( '-dpr', type=int, default=0, - help='Starting data parallel rank for secondary ' - 'nodes.') + help='Starting data parallel rank for secondary nodes.') serve_parser.add_argument( "--config", type=str, @@ -104,6 +104,14 @@ def run_headless(args: argparse.Namespace): raise RuntimeError("data_parallel_size_local must be > 0 in " "headless mode") + # Catch SIGTERM and SIGINT to allow graceful shutdown. + def signal_handler(signum, frame): + logger.debug("Received %d signal.", signum) + raise SystemExit + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + logger.info( "Launching %d data parallel engine(s) in headless mode, " "with head node address %s.", local_engine_count, input_address) @@ -124,4 +132,5 @@ def run_headless(args: argparse.Namespace): try: engine_manager.join_first() finally: + logger.info("Shutting down.") engine_manager.close() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 47e1b70cb04..b218eda8418 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -338,6 +338,9 @@ def __init__( output_address = self.startup_handshake( input_socket, on_head_node, vllm_config.parallel_config) + # Update config which may have changed from the handshake. + vllm_config.__post_init__() + # Set up data parallel environment. self._init_data_parallel(vllm_config) @@ -436,7 +439,7 @@ def signal_handler(signum, frame): try: parallel_config: ParallelConfig = kwargs[ "vllm_config"].parallel_config - if parallel_config.data_parallel_size > 1: + if parallel_config.data_parallel_size > 1 or dp_rank > 0: # Set data parallel rank for this engine process. 
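                # (dp_rank may be > 0 on secondary nodes, which are launched
                # without knowing the full data parallel size.)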
parallel_config.data_parallel_rank = dp_rank parallel_config.data_parallel_rank_local = local_dp_rank From c76e8e5b278910badecf0496a4ac5bcb59354d73 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 4 Apr 2025 16:54:23 -0700 Subject: [PATCH 11/22] [Perf] API-server scaleout with all-to-all server-engine comms Signed-off-by: Nick Hill --- tests/v1/core/test_kv_cache_utils.py | 1 - tests/v1/core/test_prefix_caching.py | 1 - tests/v1/core/test_scheduler.py | 3 +- vllm/entrypoints/cli/serve.py | 159 +++++++++- vllm/entrypoints/openai/api_server.py | 75 +++-- vllm/utils.py | 6 +- vllm/v1/core/sched/interface.py | 10 +- vllm/v1/core/sched/scheduler.py | 52 +++- vllm/v1/engine/__init__.py | 6 +- vllm/v1/engine/async_llm.py | 11 +- vllm/v1/engine/coordinator.py | 207 ++++++++++++ vllm/v1/engine/core.py | 238 +++++++++----- vllm/v1/engine/core_client.py | 433 +++++++++++++------------- vllm/v1/metrics/loggers.py | 39 +-- vllm/v1/request.py | 5 +- vllm/v1/utils.py | 155 ++++++++- 16 files changed, 1018 insertions(+), 383 deletions(-) create mode 100644 vllm/v1/engine/coordinator.py diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index e8069b8c6d7..f197026e40c 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -44,7 +44,6 @@ def make_request(request_id, multi_modal_placeholders=mm_positions, sampling_params=SamplingParams(max_tokens=17), eos_token_id=100, - arrival_time=0, lora_request=None, cache_salt=cache_salt, ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 4c05e0b87fc..5343059fbeb 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -38,7 +38,6 @@ def make_request(request_id, sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs), eos_token_id=100, - arrival_time=0, lora_request=None, cache_salt=cache_salt, ) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index ee4e95856f2..547ce49b00e 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -138,7 +138,6 @@ def create_requests(num_requests: int, multi_modal_placeholders=mm_position, multi_modal_hashes=None, eos_token_id=EOS_TOKEN_ID, - arrival_time=0, ) requests.append(request) return requests @@ -732,7 +731,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): prompt_logprobs_dict={}, ) engine_core_outputs = scheduler.update_from_output(output, - model_runner_output) + model_runner_output)[0] for i in range(len(requests)): running_req = scheduler.running[i] diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 04be7c03399..5c2f94e08ed 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -1,22 +1,33 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import multiprocessing +import os import signal +import sys +from multiprocessing.context import SpawnProcess +from typing import Any import uvloop +import zmq import vllm.envs as envs from vllm import AsyncEngineArgs from vllm.entrypoints.cli.types import CLISubcommand -from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.api_server import (run_server, run_server_worker, + setup_server) from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) +from vllm.executor.multiproc_worker_utils import _add_prefix from vllm.logger import init_logger from vllm.usage.usage_lib import 
UsageContext -from vllm.utils import FlexibleArgumentParser, get_tcp_uri +from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx +from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.core_client import CoreEngineProcManager from vllm.v1.executor.abstract import Executor +from vllm.v1.utils import (CoreEngine, get_engine_client_zmq_addr, + wait_for_engine_startup) logger = init_logger(__name__) @@ -34,9 +45,12 @@ def cmd(args: argparse.Namespace) -> None: if hasattr(args, 'model_tag') and args.model_tag is not None: args.model = args.model_tag - if args.headless: + if args.headless or args.api_server_count < 1: run_headless(args) + elif args.api_server_count > 1: + run_multi_api_server(args) else: + # Single API server (this process). uvloop.run(run_server(args)) def validate(self, args: argparse.Namespace) -> None: @@ -67,6 +81,11 @@ def subparser_init( type=int, default=0, help='Starting data parallel rank for secondary nodes.') + serve_parser.add_argument('--api-server-count', + '-asc', + type=int, + default=1, + help='How many API server processes to run.') serve_parser.add_argument( "--config", type=str, @@ -86,6 +105,9 @@ def cmd_init() -> list[CLISubcommand]: def run_headless(args: argparse.Namespace): + if args.api_server_count > 1: + raise RuntimeError("api_server_count can't be set in headless mode") + # Create the EngineConfig. engine_args = AsyncEngineArgs.from_cli_args(args) usage_context = UsageContext.OPENAI_API_SERVER @@ -98,7 +120,7 @@ def run_headless(args: argparse.Namespace): local_engine_count = parallel_config.data_parallel_size_local host = parallel_config.data_parallel_master_ip port = engine_args.data_parallel_rpc_port # add to config too - input_address = get_tcp_uri(host, port) + handshake_address = get_tcp_uri(host, port) if local_engine_count <= 0: raise RuntimeError("data_parallel_size_local must be > 0 in " @@ -114,7 +136,7 @@ def signal_handler(signum, frame): logger.info( "Launching %d data parallel engine(s) in headless mode, " - "with head node address %s.", local_engine_count, input_address) + "with head node address %s.", local_engine_count, handshake_address) # Create the engines. engine_manager = CoreEngineProcManager( @@ -124,7 +146,7 @@ def signal_handler(signum, frame): local_start_index=0, vllm_config=vllm_config, on_head_node=False, - input_address=input_address, + handshake_address=handshake_address, executor_class=Executor.get_class(vllm_config), log_stats=not engine_args.disable_log_stats, ) @@ -134,3 +156,128 @@ def signal_handler(signum, frame): finally: logger.info("Shutting down.") engine_manager.close() + + +def run_multi_api_server(args: argparse.Namespace): + + assert not args.headless + num_api_servers = args.api_server_count + # assert num_api_servers > 1 + + listen_address, sock = setup_server(args) + + engine_args = AsyncEngineArgs.from_cli_args(args) + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = engine_args.create_engine_config(usage_context=usage_context) + parallel_config = vllm_config.parallel_config + + assert parallel_config.data_parallel_rank == 0 + + dp_size = parallel_config.data_parallel_size + local_engine_count = parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + local_only = local_engine_count == dp_size + + # Set up input and output addresses. 
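+    # One input/output address pair per API server, giving every front-end
+    # process its own sockets to each engine (all-to-all comms).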
+ input_addresses = [ + get_engine_client_zmq_addr(local_only, host) + for _ in range(num_api_servers) + ] + output_addresses = [ + get_engine_client_zmq_addr(local_only, host) + for _ in range(num_api_servers) + ] + + addresses: dict[str, Any] = { + "input_addresses": input_addresses, + "output_addresses": output_addresses, + } + + # Set up coordinator for dp > 1. + coordinator = None + stats_update_address = None + if dp_size > 1: + # TODO "ready" event for coordinator + coordinator = DPCoordinator(parallel_config) + addresses.update(coordinator.get_engine_socket_addresses()) + stats_update_address = coordinator.get_stats_publish_address() + + handshake_address = get_engine_client_zmq_addr( + local_only, host, parallel_config.data_parallel_rpc_port) + + with zmq_socket_ctx(handshake_address, zmq.ROUTER, + bind=True) as handshake_socket: + + # Start local engines. + if not local_engine_count: + local_engine_manager = None + else: + local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, + vllm_config=vllm_config, + executor_class=Executor.get_class(vllm_config), + log_stats=not engine_args.disable_log_stats, + handshake_address=handshake_address, + on_head_node=True, + local_engine_count=local_engine_count, + start_index=0, + local_start_index=0) + + # Start API servers. + spawn_context = multiprocessing.get_context("spawn") + api_server_workers: list[SpawnProcess] = [] + for i, in_addr, out_addr in zip(range(num_api_servers), + input_addresses, output_addresses): + client_config = { + "input_address": in_addr, + "output_address": out_addr, + "client_index": i + } + if stats_update_address is not None: + client_config["stats_update_address"] = stats_update_address + + # TODO check signal propagation + proc = spawn_context.Process(target=run_api_server_worker, + name=f"ApiServer_{i}", + args=(listen_address, sock, args, + client_config)) + api_server_workers.append(proc) + proc.start() + + # Wait for engine handshakes to complete. + core_engines = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(dp_size) + ] + + wait_for_engine_startup( + handshake_socket, + addresses, + core_engines, + parallel_config, + vllm_config.cache_config, + local_engine_manager, + coordinator.proc if coordinator else None, + ) + + # TODO handle failures / clean shutdown here + for proc in api_server_workers: + proc.join() + + +def run_api_server_worker(listen_address, + sock, + args, + client_config=None, + **uvicorn_kwargs) -> None: + + # Add process-specific prefix to stdout and stderr. 
+ from multiprocessing import current_process + process_name = current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + uvloop.run( + run_server_worker(listen_address, sock, args, client_config, + **uvicorn_kwargs)) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 9746d9697a6..604e0df9f2d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -17,7 +17,7 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import Annotated, Optional, Union +from typing import Annotated, Any, Optional, Union import uvloop from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request @@ -137,14 +137,17 @@ async def _force_log(): @asynccontextmanager async def build_async_engine_client( - args: Namespace) -> AsyncIterator[EngineClient]: + args: Namespace, + client_config: Optional[dict[str, Any]] = None, +) -> AsyncIterator[EngineClient]: # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit engine_args = AsyncEngineArgs.from_cli_args(args) async with build_async_engine_client_from_engine_args( - engine_args, args.disable_frontend_multiprocessing) as engine: + engine_args, args.disable_frontend_multiprocessing, + client_config) as engine: yield engine @@ -152,6 +155,7 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, + client_config: Optional[dict[str, Any]] = None, ) -> AsyncIterator[EngineClient]: """ Create EngineClient, either: @@ -174,12 +178,16 @@ async def build_async_engine_client_from_engine_args( from vllm.v1.engine.async_llm import AsyncLLM async_llm: Optional[AsyncLLM] = None + client_index = client_config.pop( + "client_index") if client_config else 0 try: async_llm = AsyncLLM.from_vllm_config( vllm_config=vllm_config, usage_context=usage_context, disable_log_requests=engine_args.disable_log_requests, - disable_log_stats=engine_args.disable_log_stats) + disable_log_stats=engine_args.disable_log_stats, + client_addresses=client_config, + client_index=client_index) yield async_llm finally: if async_llm: @@ -1038,16 +1046,10 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket: return sock -async def run_server(args, **uvicorn_kwargs) -> None: - logger.info("vLLM API server version %s", VLLM_VERSION) - logger.info("args: %s", args) - - if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: - ToolParserManager.import_tool_parser(args.tool_parser_plugin) - +def validate_api_server_args(args): valid_tool_parses = ToolParserManager.tool_parsers.keys() if args.enable_auto_tool_choice \ - and args.tool_call_parser not in valid_tool_parses: + and args.tool_call_parser not in valid_tool_parses: raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " f"(chose from {{ {','.join(valid_tool_parses)} }})") @@ -1058,6 +1060,16 @@ async def run_server(args, **uvicorn_kwargs) -> None: f"invalid reasoning parser: {args.reasoning_parser} " f"(chose from {{ {','.join(valid_reasoning_parses)} }})") + +def setup_server(args): + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + 
validate_api_server_args(args) + # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. # see https://github.com/vllm-project/vllm/issues/8204 @@ -1074,22 +1086,39 @@ def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler) - async with build_async_engine_client(args) as engine_client: + addr, port = sock_addr + is_ssl = args.ssl_keyfile and args.ssl_certfile + host_part = f"[{addr}]" if is_valid_ipv6_address( + addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" + + return listen_address, sock + + +async def run_server(args, **uvicorn_kwargs) -> None: + listen_address, sock = setup_server(args) + await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) + + +async def run_server_worker(listen_address, + sock, + args, + client_config=None, + **uvicorn_kwargs) -> None: + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + server_index = client_config.get("client_index", 0) if client_config else 0 + + async with build_async_engine_client(args, client_config) as engine_client: app = build_app(args) vllm_config = await engine_client.get_vllm_config() await init_app_state(engine_client, vllm_config, app.state, args) - def _listen_addr(a: str) -> str: - if is_valid_ipv6_address(a): - return '[' + a + ']' - return a or "0.0.0.0" - - is_ssl = args.ssl_keyfile and args.ssl_certfile - logger.info("Starting vLLM API server on http%s://%s:%d", - "s" if is_ssl else "", _listen_addr(sock_addr[0]), - sock_addr[1]) - + logger.info("Starting vLLM API server %d on %s", server_index, + listen_address) shutdown_task = await serve_http( app, sock=sock, diff --git a/vllm/utils.py b/vllm/utils.py index ae021672d5a..76f45bebb10 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2311,6 +2311,7 @@ def make_zmq_socket( socket_type: Any, bind: Optional[bool] = None, identity: Optional[bytes] = None, + linger: Optional[int] = None, ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] """Make a ZMQ socket with the proper bind/connect semantics.""" @@ -2330,7 +2331,7 @@ def make_zmq_socket( buf_size = -1 # Use system default buffer size if bind is None: - bind = socket_type != zmq.PUSH + bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB) if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): socket.setsockopt(zmq.RCVHWM, 0) @@ -2343,6 +2344,9 @@ def make_zmq_socket( if identity is not None: socket.setsockopt(zmq.IDENTITY, identity) + if linger is not None: + socket.setsockopt(zmq.LINGER, linger) + # Determine if the path is a TCP socket with an IPv6 address. # Enable IPv6 on the zmq socket if so. scheme, host, _ = split_zmq_path(path) diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 0b328f51090..7945048dff2 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -44,7 +44,7 @@ def update_from_output( self, scheduler_output: "SchedulerOutput", model_runner_output: "ModelRunnerOutput", - ) -> "EngineCoreOutputs": + ) -> dict[int, "EngineCoreOutputs"]: """Update the scheduler state based on the model runner output. This method is called after the model runner has processed the scheduled @@ -54,7 +54,8 @@ def update_from_output( for each request. Returns: - A EngineCoreOutputs object containing the outputs for each request. 
+ A dict of client index to EngineCoreOutputs object containing the + outputs for each request originating from that client. """ raise NotImplementedError @@ -125,6 +126,11 @@ def reset_prefix_cache(self) -> bool: """ raise NotImplementedError + @abstractmethod + def get_request_counts(self) -> tuple[int, int]: + """Returns (num_running_reqs, num_waiting_reqs).""" + raise NotImplementedError + @abstractmethod def make_stats(self) -> Optional["SchedulerStats"]: """Make a SchedulerStats object for logging. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ae7280a1470..02652aab401 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -57,7 +57,8 @@ def __init__( # request ids should be included in the EngineCoreOutputs returned # by update_from_outputs(). This is currently used in the multi-engine # case to track request lifetimes efficiently. - self.include_finished_set = include_finished_set + self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( + defaultdict(set) if include_finished_set else None) # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -641,7 +642,7 @@ def update_from_output( self, scheduler_output: SchedulerOutput, model_runner_output: ModelRunnerOutput, - ) -> EngineCoreOutputs: + ) -> dict[int, EngineCoreOutputs]: sampled_token_ids = model_runner_output.sampled_token_ids spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs @@ -649,7 +650,7 @@ def update_from_output( num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: list[Request] = [] - outputs: list[EngineCoreOutput] = [] + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below @@ -743,7 +744,7 @@ def update_from_output( prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids: # Add EngineCoreOutput for this Request. - outputs.append( + outputs[request.client_index].append( EngineCoreOutput( request_id=req_id, new_token_ids=new_token_ids, @@ -764,17 +765,35 @@ def update_from_output( self._cached_reqs_data[req_data.req_id].append(req_data) self.running = new_running - engine_core_outputs = EngineCoreOutputs( - outputs=outputs, - scheduler_stats=self.make_stats(spec_decoding_stats), - ) - if self.include_finished_set: - #TODO currently sending duplicates here, improve this - engine_core_outputs.finished_requests = ( - scheduler_output.finished_req_ids | self.finished_req_ids) + + engine_core_outputs = { + client_index: EngineCoreOutputs(outputs=outs) + for client_index, outs in outputs.items() + } + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids is not None: + # Include ids of requests that finished since last outputs + # were sent. + for client_index, finished_set in finished_req_ids.items(): + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs( + finished_requests=finished_set) + finished_req_ids.clear() + + if engine_core_outputs: + # Return stats to only one of the front-ends. 
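+            # (Attaching them to a single arbitrary client avoids sending
+            # duplicate stats to every front-end on each step.)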
+ next(iter(engine_core_outputs.values())).scheduler_stats = ( + self.make_stats(spec_decoding_stats)) return engine_core_outputs + def get_request_counts(self) -> tuple[int, int]: + """Returns (num_running_reqs, num_waiting_reqs).""" + return len(self.running), len(self.waiting) + def add_request(self, request: Request) -> None: self.waiting.append(request) self.requests[request.request_id] = request @@ -815,9 +834,12 @@ def _free_request(self, request: Request) -> None: self.kv_cache_manager.free(request) self.kv_cache_manager.free_block_hashes(request) self.encoder_cache_manager.free(request) - self._cached_reqs_data.pop(request.request_id, None) - del self.requests[request.request_id] - self.finished_req_ids.add(request.request_id) + request_id = request.request_id + self._cached_reqs_data.pop(request_id, None) + del self.requests[request_id] + self.finished_req_ids.add(request_id) + if self.finished_req_ids_dict is not None: + self.finished_req_ids_dict[request.client_index].add(request_id) def get_num_unfinished_requests(self) -> int: return len(self.waiting) + len(self.running) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index e33d1a1e5dc..48d38da8e88 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -44,10 +44,6 @@ class EngineCoreRequest( omit_defaults=True, # type: ignore[call-arg] gc=False): # type: ignore[call-arg] - # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, - # but this object is currently not playing well with msgspec - # due to circular imports and typing we have in data.py - request_id: str prompt_token_ids: list[int] mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] @@ -59,6 +55,8 @@ class EngineCoreRequest( lora_request: Optional[LoRARequest] cache_salt: Optional[str] + client_index: int = 0 + # Used in DP case to indicate which wave of requests this is expected to # belong to, to cover a race condition where the request is sent before # a wave finished notification is received. diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 14ce820cc39..31b844e3a68 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -52,6 +52,8 @@ def __init__( log_requests: bool = True, start_engine_loop: bool = True, stat_loggers: Optional[list[StatLoggerFactory]] = None, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, ) -> None: """ Create an AsyncLLM. @@ -119,6 +121,8 @@ def __init__( vllm_config=vllm_config, executor_class=executor_class, log_stats=self.log_stats, + client_addresses=client_addresses, + client_index=client_index, ) for stat_logger in self.stat_loggers[0]: stat_logger.log_engine_initialized() @@ -139,6 +143,8 @@ def from_vllm_config( stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_requests: bool = False, disable_log_stats: bool = False, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, ) -> "AsyncLLM": if not envs.VLLM_USE_V1: raise ValueError( @@ -156,6 +162,8 @@ def from_vllm_config( log_requests=not disable_log_requests, log_stats=not disable_log_stats, usage_context=usage_context, + client_addresses=client_addresses, + client_index=client_index, ) @classmethod @@ -392,7 +400,6 @@ async def output_handler(): # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. 
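                # (scheduler_stats may now be None: each engine attaches
                # stats to only one client's outputs per step.)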
if stat_loggers: - assert outputs.scheduler_stats is not None AsyncLLM._record_stats( stat_loggers[outputs.engine_index], scheduler_stats=outputs.scheduler_stats, @@ -416,7 +423,7 @@ async def abort(self, request_id: str) -> None: @staticmethod def _record_stats( stat_loggers: list[StatLoggerBase], - scheduler_stats: SchedulerStats, + scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], ): """static so that it can be used from the output_handler task diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py new file mode 100644 index 00000000000..d0421e00ade --- /dev/null +++ b/vllm/v1/engine/coordinator.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +import multiprocessing +import sys +import time +from typing import Optional + +import msgspec.msgpack +import zmq + +from vllm.config import ParallelConfig +from vllm.logger import init_logger +from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket +from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType +from vllm.v1.serial_utils import MsgpackDecoder +from vllm.v1.utils import get_engine_client_zmq_addr + +logger = init_logger(__name__) + + +class DPCoordinator: + + def __init__(self, parallel_config: ParallelConfig): + + # Assume coordinator is colocated with front-end procs. + front_publish_address = get_open_zmq_ipc_path() + + dp_size = parallel_config.data_parallel_size + assert dp_size > 1, "Coordinator only used for data parallel" + + local_only = dp_size == parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + back_publish_address = get_engine_client_zmq_addr(local_only, host) + back_output_address = get_engine_client_zmq_addr(local_only, host) + + context = get_mp_context() + self.proc: multiprocessing.Process = context.Process( + target=CoordinatorProc.run_coordinator, + name="VLLM_DP_Coordinator", + kwargs={ + "engine_count": parallel_config.data_parallel_size, + "front_publish_address": front_publish_address, + "back_output_address": back_output_address, + "back_publish_address": back_publish_address, + }, + daemon=True) + self.proc.start() + + self.stats_publish_address = front_publish_address + self.coord_in_address = back_publish_address + self.coord_out_address = back_output_address + + def get_stats_publish_address(self) -> str: + return self.stats_publish_address + + def get_engine_socket_addresses(self) -> dict[str, str]: + return { + "coord_in_address": self.coord_in_address, + "coord_out_address": self.coord_out_address, + } + + def close(self): + self.proc.terminate() + + +class EngineState: + + def __init__(self): + self.request_counts = [0, 0] # [waiting, running] + + +class CoordinatorProc: + + def __init__(self, engine_count: int): + + self.ctx = zmq.Context() + + self.engines = [EngineState() for _ in range(engine_count)] + + self.current_wave = 0 + self.engines_running = False + self.stats_changed = False + + @staticmethod + def run_coordinator( + engine_count: int, + front_publish_address: str, + back_output_address: str, + back_publish_address: str, + ): + coordinator = CoordinatorProc(engine_count=engine_count) + + try: + coordinator.process_input_socket( + front_publish_address, + back_output_address, + back_publish_address, + ) + except KeyboardInterrupt: + logger.info("DP Coordinator process exiting") + + def process_input_socket(self, front_publish_address: str, + back_output_address: str, + back_publish_address: str): + + decoder = MsgpackDecoder(EngineCoreOutputs) + + 
with make_zmq_socket( + path=front_publish_address, # IPC + ctx=self.ctx, + socket_type=zmq.XPUB, + bind=True, + ) as publish_front, make_zmq_socket( + path=back_output_address, # IPC or TCP + ctx=self.ctx, + socket_type=zmq.PULL, + bind=True, + ) as output_back, make_zmq_socket( + path=back_publish_address, # IPC or TCP + ctx=self.ctx, + socket_type=zmq.XPUB, + bind=True, + ) as publish_back: + + poller = zmq.Poller() + poller.register(publish_front, zmq.POLLIN) + poller.register(output_back, zmq.POLLIN) + last_publish = 0 + while True: + elapsed = int(time.time() * 1000) - last_publish + wait_for = 100 if self.stats_changed else 3000 + events = poller.poll(timeout=max(0, wait_for - elapsed)) + if not events: + engine_list = self._get_engine_list() + to_publish = (engine_list, self.current_wave, + self.engines_running) + msg = msgspec.msgpack.encode(to_publish) + publish_front.send(msg) + last_publish = int(time.time() * 1000) + self.stats_changed = False + continue + + events = dict(events) + + if publish_front in events: + buffer = publish_front.recv() + if buffer == b'\x01': + # Ignore subscription messages. + continue + engine_index, wave = msgspec.msgpack.decode(buffer) + if wave < self.current_wave: + engine_index = None + if not self.engines_running: + self.engines_running = True + self.stats_changed = True + self._send_start_wave(publish_back, self.current_wave, + engine_index) + + if output_back in events: + buffer = output_back.recv() + outputs: EngineCoreOutputs = decoder.decode(buffer) + + assert not outputs.outputs + assert outputs.utility_output is None + + eng_index = outputs.engine_index + if outputs.scheduler_stats: + stats = self.engines[eng_index].request_counts + stats[0] = outputs.scheduler_stats.num_waiting_reqs + stats[1] = outputs.scheduler_stats.num_running_reqs + self.stats_changed = True + + #TODO record prometheus metrics here? + + if outputs.wave_complete is not None: + if self.current_wave <= wave: + self.current_wave = wave + 1 + self.engines_running = False + self.stats_changed = True + elif outputs.start_wave is not None and ( + wave > self.current_wave or + (wave == self.current_wave + and not self.engines_running)): + # Engine received request for a non-current wave so + # we must ensure that other engines progress to the + # next wave. 
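+                        # (The engine that received the request is excluded
+                        # from the broadcast; it is already in this wave.)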
+ self.current_wave = wave + self.engines_running = True + self.stats_changed = True + self._send_start_wave(publish_back, wave, eng_index) + + @staticmethod + def _send_start_wave(socket: zmq.Socket, wave: int, + exclude_engine_index: Optional[int]): + wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index)) + socket.send_multipart( + (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) + + def _get_engine_list(self) -> Optional[list[int]]: + shortlist: list[int] = [] + min_counts = [sys.maxsize, sys.maxsize] + for i, e in enumerate(self.engines): + if e.request_counts <= min_counts: + if e.request_counts < min_counts: + min_counts = e.request_counts + shortlist.clear() + shortlist.append(i) + return None if len(shortlist) == len(self.engines) else shortlist diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 76390da25a2..b8dbb1f261e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -7,6 +7,7 @@ import time from collections import deque from concurrent.futures import Future +from contextlib import ExitStack from inspect import isclass, signature from logging import DEBUG from typing import Any, Callable, Optional, TypeVar, Union @@ -21,7 +22,7 @@ from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx +from vllm.utils import make_zmq_socket, resolve_obj_by_qualname from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface @@ -32,6 +33,7 @@ from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -191,16 +193,13 @@ def abort_requests(self, request_ids: list[str]): self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) - def step(self) -> EngineCoreOutputs: + def step(self) -> dict[int, EngineCoreOutputs]: """Schedule, execute, and make output.""" # Check for any requests remaining in the scheduler - unfinished, # or finished and not yet removed from the batch. if not self.scheduler.has_requests(): - return EngineCoreOutputs( - outputs=[], - scheduler_stats=self.scheduler.make_stats(), - ) + return {} scheduler_output = self.scheduler.schedule() output = self.model_executor.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( @@ -208,7 +207,7 @@ def step(self) -> EngineCoreOutputs: return engine_core_outputs - def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: + def step_with_batch_queue(self) -> Optional[dict[int, EngineCoreOutputs]]: """Schedule and execute batches with the batch queue. Note that if nothing to output in this step, None is returned. @@ -250,8 +249,8 @@ def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: # Blocking until the first result is available. 
model_output = future.result() self.batch_queue.task_done() - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output) + engine_core_outputs = (self.scheduler.update_from_output( + scheduler_output, model_output)) return engine_core_outputs @@ -320,7 +319,7 @@ def __init__( self, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, engine_index: int = 0, @@ -333,15 +332,22 @@ def __init__( # Create input socket. input_ctx = zmq.Context() identity = engine_index.to_bytes(length=2, byteorder="little") - input_socket = make_zmq_socket(input_ctx, - input_address, - zmq.DEALER, - identity=identity, - bind=False) - try: + with make_zmq_socket(input_ctx, + handshake_address, + zmq.DEALER, + identity=identity, + linger=5000, + bind=False) as handshake_socket: + # Register engine with front-end. - output_address = self.startup_handshake( - input_socket, on_head_node, vllm_config.parallel_config) + addresses = self.startup_handshake(handshake_socket, on_head_node, + vllm_config.parallel_config) + input_addresses: list[str] = addresses["input_addresses"] + output_addresses: list[str] = addresses["output_addresses"] + coord_in_addr: Optional[str] = addresses.get("coord_in_address") + coord_out_addr: Optional[str] = addresses.get("coord_out_address") + self.client_count = len(output_addresses) + self.coordinator = coord_out_addr is not None # Update config which may have changed from the handshake. vllm_config.__post_init__() @@ -353,42 +359,41 @@ def __init__( super().__init__(vllm_config, executor_class, log_stats, executor_fail_callback) + self.engine_index = engine_index self.step_fn = (self.step if self.batch_queue is None else self.step_with_batch_queue) self.engines_running = False + self.last_counts = (0, 0) # Send ready message. num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks - input_socket.send( + handshake_socket.send( msgspec.msgpack.encode({ "status": "READY", "local": on_head_node, "num_gpu_blocks": num_gpu_blocks, })) - # Background Threads and Queues for IO. These enable us to - # overlap ZMQ socket IO with GPU since they release the GIL, - # and to overlap some serialization/deserialization with the - # model forward pass. - # Threads handle Socket <-> Queues and core_busy_loop uses Queue. - self.input_queue = input_queue - self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() - threading.Thread(target=self.process_input_socket, - args=(input_socket, ), - daemon=True).start() - input_socket = None - self.output_thread = threading.Thread( - target=self.process_output_socket, - args=(output_address, engine_index), - daemon=True) - self.output_thread.start() - finally: - if input_socket is not None: - input_socket.close(linger=0) + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. 
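+            # Output queue entries are (client_index, outputs) pairs; a
+            # client_index of -1 routes the message to the DP coordinator.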
+ self.input_queue = input_queue + self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], + bytes]]() + threading.Thread(target=self.process_input_sockets, + args=(input_addresses, coord_in_addr, identity), + daemon=True).start() + self.output_thread = threading.Thread( + target=self.process_output_sockets, + args=(output_addresses, coord_out_addr, engine_index), + daemon=True) + self.output_thread.start() @staticmethod def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, - parallel_config: ParallelConfig) -> str: + parallel_config: ParallelConfig) -> dict[str, Any]: # Send registration message. input_socket.send( @@ -407,14 +412,11 @@ def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, init_message = msgspec.msgpack.decode(init_bytes) logger.debug("Received init message: %s", init_message) - output_socket_address = init_message["output_socket_address"] - #TBD(nick) maybe replace IP with configured head node address - - received_parallel_config = init_message["parallel_config"] + received_parallel_config = init_message.pop("parallel_config") for key, value in received_parallel_config.items(): setattr(parallel_config, key, value) - return output_socket_address + return init_message["addresses"] @staticmethod def run_engine_core(*args, @@ -506,9 +508,22 @@ def _process_engine_step(self): # Step the engine core. outputs = self.step_fn() + if not outputs: + return + # Put EngineCoreOutputs into the output queue. - if outputs is not None: - self.output_queue.put_nowait(outputs) + for output in outputs.items(): + self.output_queue.put_nowait(output) + + if self.coordinator: + # If there is a DP coordinator, publish our request counts + # (if they've changed) + counts = self.scheduler.get_request_counts() + if counts != self.last_counts: + self.last_counts = counts + stats = SchedulerStats(*counts) + self.output_queue.put_nowait( + (-1, EngineCoreOutputs(scheduler_stats=stats))) def _handle_client_request(self, request_type: EngineCoreRequestType, request: Any) -> None: @@ -519,7 +534,7 @@ def _handle_client_request(self, request_type: EngineCoreRequestType, elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) elif request_type == EngineCoreRequestType.UTILITY: - call_id, method_name, args = request + client_idx, call_id, method_name, args = request output = UtilityOutput(call_id) try: method = getattr(self, method_name) @@ -530,7 +545,7 @@ def _handle_client_request(self, request_type: EngineCoreRequestType, output.failure_message = (f"Call to {method_name} method" f" failed: {str(e)}") self.output_queue.put_nowait( - EngineCoreOutputs(utility_output=output)) + (client_idx, EngineCoreOutputs(utility_output=output))) elif request_type == EngineCoreRequestType.EXECUTOR_FAILED: raise RuntimeError("Executor failed.") else: @@ -563,27 +578,68 @@ def _send_engine_dead(self): logger.fatal("vLLM shutdown signal from EngineCore failed " "to send. Please report this issue.") - def process_input_socket(self, input_socket: zmq.Socket): + def process_input_sockets(self, input_addresses: list[str], + coord_input_address: Optional[str], + identity: bytes): """Input socket IO thread.""" # Msgpack serialization decoding. add_request_decoder = MsgpackDecoder(EngineCoreRequest) generic_decoder = MsgpackDecoder() - while True: - # (RequestType, RequestData) - type_frame, *data_frames = input_socket.recv_multipart(copy=False) - request_type = EngineCoreRequestType(bytes(type_frame.buffer)) - - # Deserialize the request data. 
- decoder = add_request_decoder if ( - request_type == EngineCoreRequestType.ADD) else generic_decoder - request = decoder.decode(data_frames) - - # Push to input queue for core busy loop. - self.input_queue.put_nowait((request_type, request)) + with ExitStack() as stack, zmq.Context() as ctx: + input_sockets = [ + stack.enter_context( + make_zmq_socket(ctx, + input_address, + zmq.DEALER, + identity=identity, + bind=False)) + for input_address in input_addresses + ] + if coord_input_address is None: + coord_socket = None + else: + coord_socket = stack.enter_context( + make_zmq_socket(ctx, + coord_input_address, + zmq.XSUB, + identity=identity, + bind=False)) + # Send subscription message to coordinator. + coord_socket.send(b'\x01') + + # Register sockets with poller. + poller = zmq.Poller() + for input_socket in input_sockets: + # Send initial message to each input socket - this is required + # before the front-end ROUTER socket can send input messages + # back to us. + input_socket.send(b'') + poller.register(input_socket, zmq.POLLIN) + if coord_socket is not None: + poller.register(coord_socket, zmq.POLLIN) - def process_output_socket(self, output_path: str, engine_index: int): + while True: + for input_socket, _ in poller.poll(): + # (RequestType, RequestData) + type_frame, *data_frames = input_socket.recv_multipart( + copy=False) + request_type = EngineCoreRequestType( + bytes(type_frame.buffer)) + + # Deserialize the request data. + decoder = add_request_decoder if ( + request_type + == EngineCoreRequestType.ADD) else generic_decoder + request = decoder.decode(data_frames) + + # Push to input queue for core busy loop. + self.input_queue.put_nowait((request_type, request)) + + def process_output_sockets(self, output_paths: list[str], + coord_output_path: Optional[str], + engine_index: int): """Output socket IO thread.""" # Msgpack serialization encoding. @@ -597,30 +653,50 @@ def process_output_socket(self, output_path: str, engine_index: int): # We must set linger to ensure the ENGINE_CORE_DEAD # message is sent prior to closing the socket. - with zmq_socket_ctx(output_path, zmq.constants.PUSH, - linger=4000) as socket: + with ExitStack() as stack, zmq.Context() as ctx: + sockets = [ + stack.enter_context( + make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)) + for output_path in output_paths + ] + coord_socket = stack.enter_context( + make_zmq_socket( + ctx, coord_output_path, zmq.PUSH, bind=False, + linger=4000)) if coord_output_path is not None else None + max_reuse_bufs = len(sockets) + 1 + while True: - outputs = self.output_queue.get() - if outputs == EngineCoreProc.ENGINE_CORE_DEAD: - socket.send(outputs, copy=False) + output = self.output_queue.get() + if output == EngineCoreProc.ENGINE_CORE_DEAD: + for socket in sockets: + socket.send(output) + #TODO also send to coordinator here? break - assert not isinstance(outputs, bytes) + assert not isinstance(output, bytes) + client_index, outputs = output outputs.engine_index = engine_index + if client_index == -1: + # Don't reuse buffer for coordinator message + # which will be very small. + assert coord_socket is not None + coord_socket.send_multipart(encoder.encode(outputs)) + continue + # Reclaim buffers that zmq is finished with. 
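                # (Each pending entry is (tracker, outputs-ref, buffer); once
                # zmq reports the zero-copy send done, the bytearray can back
                # a later encode_into call.)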
while pending and pending[-1][0].done: reuse_buffers.append(pending.pop()[2]) buffer = reuse_buffers.pop() if reuse_buffers else bytearray() buffers = encoder.encode_into(outputs, buffer) - tracker = socket.send_multipart(buffers, - copy=False, - track=True) + tracker = sockets[client_index].send_multipart(buffers, + copy=False, + track=True) if not tracker.done: ref = outputs if len(buffers) > 1 else None pending.appendleft((tracker, ref, buffer)) - elif len(reuse_buffers) < 2: - # Keep at most 2 buffers to reuse. + elif len(reuse_buffers) < max_reuse_bufs: + # Limit the number of buffers to reuse. reuse_buffers.append(buffer) @@ -632,7 +708,7 @@ def __init__( self, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -647,10 +723,11 @@ def __init__( # Counts forward-passes of the model so that we can synchronize # finished with DP peers every N steps. self.counter = 0 + self.current_wave = 0 # Initialize the engine. dp_rank = vllm_config.parallel_config.data_parallel_rank - super().__init__(vllm_config, on_head_node, input_address, + super().__init__(vllm_config, on_head_node, handshake_address, executor_class, log_stats, dp_rank) def _init_data_parallel(self, vllm_config: VllmConfig): @@ -674,7 +751,6 @@ def _init_data_parallel(self, vllm_config: VllmConfig): self.local_dp_rank = local_dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() - self.current_wave = 0 def shutdown(self): super().shutdown() @@ -689,15 +765,16 @@ def add_request(self, request: EngineCoreRequest): # Request received for an already-completed wave, notify # front-end that we need to start the next one. self.output_queue.put_nowait( - EngineCoreOutputs(start_wave=self.current_wave)) + (-1, EngineCoreOutputs(start_wave=self.current_wave))) super().add_request(request) def _handle_client_request(self, request_type: EngineCoreRequestType, request: Any) -> None: if request_type == EngineCoreRequestType.START_DP_WAVE: - new_wave: int = request - if new_wave >= self.current_wave: + new_wave, exclude_eng_index = request + if exclude_eng_index != self.engine_index and ( + new_wave >= self.current_wave): self.current_wave = new_wave if not self.engines_running: logger.debug("EngineCore starting idle loop for wave %d.", @@ -750,7 +827,8 @@ def run_busy_loop(self): logger.debug("Wave %d finished, pausing engine loop.", self.current_wave) self.output_queue.put_nowait( - EngineCoreOutputs(wave_complete=self.current_wave)) + (-1, + EngineCoreOutputs(wave_complete=self.current_wave))) self.current_wave += 1 def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 863b1024b7d..d019279b56d 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -9,26 +9,27 @@ from collections.abc import Awaitable, Sequence from concurrent.futures import Future from dataclasses import dataclass -from enum import Enum, auto from threading import Thread from typing import Any, Callable, Optional, TypeVar, Union -import msgspec +import msgspec.msgpack import zmq import zmq.asyncio -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import (get_open_port, get_open_zmq_inproc_path, - get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket) +from vllm.utils import (get_open_zmq_inproc_path, 
make_zmq_socket, + zmq_socket_ctx) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) +from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr -from vllm.v1.utils import CoreEngineProcManager +from vllm.v1.utils import (CoreEngine, CoreEngineProcManager, + get_engine_client_zmq_addr, wait_for_engine_startup) logger = init_logger(__name__) @@ -36,8 +37,6 @@ _R = TypeVar('_R') # Return type for collective_rpc -STARTUP_POLL_PERIOD_MS = 10000 - class EngineCoreClient(ABC): """ @@ -200,7 +199,7 @@ def __init__(self, *args, **kwargs): self.engine_core = EngineCore(*args, **kwargs) def get_output(self) -> EngineCoreOutputs: - return self.engine_core.step() + return self.engine_core.step().get(0) or EngineCoreOutputs() def add_request(self, request: EngineCoreRequest) -> None: self.engine_core.add_request(request) @@ -256,24 +255,6 @@ def collective_rpc(self, return self.engine_core.collective_rpc(method, timeout, args, kwargs) -class CoreEngineState(Enum): - NEW = auto() - CONNECTED = auto() - READY = auto() - - -class CoreEngine: - """One per data parallel rank.""" - - def __init__(self, index: int = 0, local: bool = True): - self.local = local - self.index = index - self.identity = index.to_bytes(length=2, byteorder="little") - - self.state = CoreEngineState.NEW - self.num_reqs_in_flight = 0 - - @dataclass class BackgroundResources: """Used as a finalizer for clean shutdown, avoiding @@ -281,9 +262,11 @@ class BackgroundResources: ctx: Union[zmq.Context] local_engine_manager: Optional[CoreEngineProcManager] = None + coordinator: Optional[DPCoordinator] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None output_queue_task: Optional[asyncio.Task] = None + stats_update_task: Optional[asyncio.Task] = None shutdown_path: Optional[str] = None # Set if any of the engines are dead. Here so that the output @@ -296,9 +279,13 @@ def __call__(self): self.engine_dead = True if self.local_engine_manager is not None: self.local_engine_manager.close() + if self.coordinator is not None: + self.coordinator.close() if self.output_queue_task is not None: self.output_queue_task.cancel() + if self.stats_update_task is not None: + self.stats_update_task.cancel() # ZMQ context termination can hang if the sockets # aren't explicitly closed first. @@ -340,6 +327,7 @@ def __init__( vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, ): self.vllm_config = vllm_config # Serialization setup. 
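
For readers new to the socket topology used throughout these changes: the front-end binds a single ROUTER socket and addresses each engine by a two-byte little-endian identity equal to its engine index, and an engine's DEALER socket must send one initial message before the ROUTER side is able to route to it. A standalone pyzmq sketch of that mechanism, separate from the patch itself (the inproc address and rank value are placeholders):

```python
import zmq

ctx = zmq.Context()
front_end = ctx.socket(zmq.ROUTER)
front_end.bind("inproc://identity-demo")

engine = ctx.socket(zmq.DEALER)
engine.setsockopt(zmq.IDENTITY, (3).to_bytes(2, "little"))  # engine index 3
engine.connect("inproc://identity-demo")
engine.send(b"")  # initial message; ROUTER cannot address this peer before it

# ROUTER sees [identity, payload] and can now send back to that identity.
identity, _ = front_end.recv_multipart()
assert int.from_bytes(identity, "little") == 3
front_end.send_multipart((identity, b"request for engine 3"))
assert engine.recv() == b"request for engine 3"

engine.close()
front_end.close()
ctx.term()
```
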
@@ -359,8 +347,8 @@ def __init__(
         try:
             parallel_config = vllm_config.parallel_config
             local_engine_count = parallel_config.data_parallel_size_local
-            start_index = parallel_config.data_parallel_rank
             local_start_index = parallel_config.data_parallel_rank_local
+            dp_size = parallel_config.data_parallel_size
 
             # SPMD mode is where there is an LLM instance per DP rank and
             # one core engine per LLM, see
@@ -372,42 +360,53 @@ def __init__(
                     CoreEngine(index=local_start_index, local=True)
                 ]
             else:
-                assert start_index == 0
+                assert parallel_config.data_parallel_rank == 0
                 local_start_index = 0
                 self.core_engines = [
                     CoreEngine(index=i, local=(i < local_engine_count))
-                    for i in range(parallel_config.data_parallel_size)
+                    for i in range(dp_size)
                 ]
 
-            input_address, output_address = self._get_zmq_addresses(
-                parallel_config, spmd_mode)
+            local_only = spmd_mode or local_engine_count == dp_size
+
+            self.stats_update_address: Optional[str] = None
+            if client_addresses is not None:
+                input_address = client_addresses["input_address"]
+                output_address = client_addresses["output_address"]
+                self.stats_update_address = client_addresses.get(
+                    "stats_update_address")
+            else:
+                host = parallel_config.data_parallel_master_ip
+                input_address = get_engine_client_zmq_addr(local_only, host)
+                output_address = get_engine_client_zmq_addr(local_only, host)
 
             # Create input and output sockets.
             self.input_socket = self.resources.input_socket = make_zmq_socket(
                 self.ctx, input_address, zmq.ROUTER, bind=True)
 
             self.resources.output_socket = make_zmq_socket(
-                self.ctx, output_address, zmq.constants.PULL)
-            # Start local engines.
-            if local_engine_count:
-                # In server mode, start_index and local_start_index will
-                # both be 0.
-                self.resources.local_engine_manager = CoreEngineProcManager(
-                    EngineCoreProc.run_engine_core,
-                    vllm_config=vllm_config,
-                    executor_class=executor_class,
-                    log_stats=log_stats,
-                    input_address=input_address,
-                    on_head_node=True,
-                    local_engine_count=local_engine_count,
-                    start_index=start_index,
-                    local_start_index=local_start_index)
+                self.ctx, output_address, zmq.PULL)
+
+            if client_addresses is None:
+                self._init_engines_direct(vllm_config, local_only,
+                                          local_start_index, input_address,
+                                          output_address, executor_class,
+                                          log_stats)
+                coordinator = self.resources.coordinator
+                if coordinator:
+                    self.stats_update_address = (
+                        coordinator.get_stats_publish_address())
+
+            # Wait for ready messages from each engine on the input socket.
+            identities = set(e.identity for e in self.core_engines)
+            sync_input_socket = zmq.Socket.shadow(self.input_socket)
+            while identities:
+                if not sync_input_socket.poll(timeout=600_000):
+                    raise TimeoutError("Timed out waiting for engines to send "
                                       "initial message on input socket.")
+                identity, _ = sync_input_socket.recv_multipart()
+                identities.remove(identity)
 
             self.core_engine = self.core_engines[0]
-
-        # Wait for engine core process(es) to start.
- self._wait_for_engine_startup(output_address, parallel_config) - self.utility_results: dict[int, AnyFuture] = {} # Request objects which may contain pytorch-allocated tensors @@ -420,117 +419,66 @@ def __init__( if not success: self._finalizer() - @staticmethod - def _get_zmq_addresses(parallel_config: ParallelConfig, - spmd_mode: bool) -> tuple[str, str]: - """Returns (input_address, output_address).""" - dp_size = parallel_config.data_parallel_size + def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, + local_start_index: int, input_address: str, + output_address: str, + executor_class: type[Executor], log_stats: bool): + """Self-contained client mode, launch engine and coordinator process + as needed.""" + + parallel_config = vllm_config.parallel_config local_engine_count = parallel_config.data_parallel_size_local + start_index = parallel_config.data_parallel_rank + host = parallel_config.data_parallel_master_ip - if local_engine_count == dp_size or spmd_mode: - input_address = get_open_zmq_ipc_path() - output_address = get_open_zmq_ipc_path() - else: - host = parallel_config.data_parallel_master_ip - input_port = parallel_config.data_parallel_rpc_port - output_port = get_open_port() - input_address = get_tcp_uri(host, input_port) - output_address = get_tcp_uri(host, output_port) - - return input_address, output_address - - def _wait_for_engine_startup(self, output_address: str, - parallel_config: ParallelConfig): - # Get a sync handle to the socket which can be sync or async. - sync_input_socket = zmq.Socket.shadow(self.input_socket) - - # Wait for engine core process(es) to send ready messages. - local_count = parallel_config.data_parallel_size_local - remote_count = len(self.core_engines) - local_count - # [local, remote] counts - conn_pending, start_pending = [local_count, remote_count], [0, 0] - - poller = zmq.Poller() - poller.register(sync_input_socket, zmq.POLLIN) - proc_manager = self.resources.local_engine_manager - if proc_manager is not None: - for sentinel in proc_manager.sentinels(): - poller.register(sentinel, zmq.POLLIN) - while any(conn_pending) or any(start_pending): - events = poller.poll(STARTUP_POLL_PERIOD_MS) - if not events: - if any(conn_pending): - logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to connect.", *conn_pending) - if any(start_pending): - logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to start.", *start_pending) - continue - if len(events) > 1 or events[0][0] != sync_input_socket: - # One of the local core processes exited. - finished = proc_manager.finished_procs( - ) if proc_manager else {} - raise RuntimeError("Engine core initialization failed. " - "See root cause above. " - f"Failed core proc(s): {finished}") - - # Receive HELLO and READY messages from the input socket. 
- eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() - eng_index = int.from_bytes(eng_identity, byteorder="little") - engine = next( - (e for e in self.core_engines if e.identity == eng_identity), - None) - if engine is None: - raise RuntimeError(f"Message from engine with unexpected data " - f"parallel rank: {eng_index}") - msg = msgspec.msgpack.decode(ready_msg_bytes) - status, local = msg["status"], msg["local"] - if local != engine.local: - raise RuntimeError(f"{status} message from " - f"{'local' if local else 'remote'} " - f"engine {eng_index}, expected it to be " - f"{'local' if engine.local else 'remote'}") - - if status == "HELLO" and engine.state == CoreEngineState.NEW: - - # Send init message with DP config info. - init_message = self.encoder.encode({ - "output_socket_address": output_address, - "parallel_config": { - "data_parallel_master_ip": - parallel_config.data_parallel_master_ip, - "data_parallel_master_port": - parallel_config.data_parallel_master_port, - "data_parallel_size": - parallel_config.data_parallel_size, - }, - }) - sync_input_socket.send_multipart((eng_identity, *init_message), - copy=False) - conn_pending[0 if local else 1] -= 1 - start_pending[0 if local else 1] += 1 - engine.state = CoreEngineState.CONNECTED - elif status == "READY" and (engine.state - == CoreEngineState.CONNECTED): - # Setup KV cache config with initialization state from - # engine core process. - - # TODO we'll receive one of these per engine in DP case. - # How should we aggregate? - self.vllm_config.cache_config.num_gpu_blocks = msg[ - "num_gpu_blocks"] - - start_pending[0 if local else 1] -= 1 - engine.state = CoreEngineState.READY - else: - raise RuntimeError(f"Unexpected {status} message for " - f"{'local' if local else 'remote'} engine " - f"{eng_index} in {engine.state} state.") + if len(self.core_engines) > 1: + self.resources.coordinator = DPCoordinator(parallel_config) + + handshake_address = get_engine_client_zmq_addr( + local_only, host, parallel_config.data_parallel_rpc_port) + + with zmq_socket_ctx(handshake_address, zmq.ROUTER, + bind=True) as handshake_socket: + + # Start local engines. + if local_engine_count: + # In server mode, start_index and local_start_index will + # both be 0. + self.resources.local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + handshake_address=handshake_address, + on_head_node=True, + local_engine_count=local_engine_count, + start_index=start_index, + local_start_index=local_start_index) - logger.debug("%s from %s core engine process %s.", status, - "local" if local else "remote", eng_index) + # Wait for engine core process(es) to start. + self._wait_for_engine_startup(handshake_socket, input_address, + output_address) + + def _wait_for_engine_startup(self, handshake_socket: zmq.Socket, + input_address: str, output_address: str): + addresses: dict[str, Any] = { + "input_addresses": [input_address], + "output_addresses": [output_address], + } + + coordinator = self.resources.coordinator + if coordinator is not None: + addresses.update(coordinator.get_engine_socket_addresses()) + + wait_for_engine_startup( + handshake_socket, + addresses, + self.core_engines, + self.vllm_config.parallel_config, + self.vllm_config.cache_config, + self.resources.local_engine_manager, + coordinator.proc if coordinator else None, + ) def shutdown(self): # Terminate background resources. 
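
For reference, the handshake driven by wait_for_engine_startup on this side and startup_handshake in the engine exchanges three msgpack payloads. A minimal sketch of their shapes, with all concrete values as placeholders (a real READY message carries the engine's actual num_gpu_blocks):

```python
import msgspec.msgpack

# Engine -> front-end, sent as soon as the engine process is up:
hello = {"status": "HELLO", "local": True}

# Front-end -> engine, once the HELLO arrives:
init = {
    "addresses": {
        "input_addresses": ["ipc:///tmp/in.sock"],
        "output_addresses": ["ipc:///tmp/out.sock"],
    },
    "parallel_config": {
        "data_parallel_master_ip": "127.0.0.1",
        "data_parallel_master_port": 29500,
        "data_parallel_size": 2,
    },
}

# Engine -> front-end, after its executor has initialized:
ready = {"status": "READY", "local": True, "num_gpu_blocks": 8192}

# msgpack round-trips all three payloads cleanly:
for msg in (hello, init, ready):
    assert msgspec.msgpack.decode(msgspec.msgpack.encode(msg)) == msg
```
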
@@ -596,8 +544,8 @@ def process_outputs_socket(): try: shutdown_socket.bind(shutdown_path) poller = zmq.Poller() - poller.register(shutdown_socket) - poller.register(out_socket) + poller.register(shutdown_socket, zmq.POLLIN) + poller.register(out_socket, zmq.POLLIN) while True: socks = poller.poll() if not socks: @@ -659,7 +607,7 @@ def call_utility(self, method: str, *args) -> Any: future: Future[Any] = Future() self.utility_results[call_id] = future self._send_input(EngineCoreRequestType.UTILITY, - (call_id, method, args)) + (0, call_id, method, args)) return future.result() @@ -718,15 +666,21 @@ def save_sharded_state(self, class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], - log_stats: bool): + def __init__(self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0): super().__init__( asyncio_mode=True, vllm_config=vllm_config, executor_class=executor_class, log_stats=log_stats, + client_addresses=client_addresses, ) + self.client_index = client_index self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, Exception]]() try: @@ -842,12 +796,13 @@ async def _call_utility_async(self, method: str, *args, future = asyncio.get_running_loop().create_future() self.utility_results[call_id] = future message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( - (call_id, method, args))) + (self.client_index, call_id, method, args))) await self._send_input_message(message, engine, args) self._ensure_output_queue_task() return await future async def add_request_async(self, request: EngineCoreRequest) -> None: + request.client_index = self.client_index await self._send_input(EngineCoreRequestType.ADD, request) self._ensure_output_queue_task() @@ -906,17 +861,109 @@ class DPAsyncMPClient(AsyncMPClient): """Asyncio-compatible client for multi-proc, multi-engine (data parallel) EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], - log_stats: bool): + def __init__(self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0): self.current_wave = 0 self.engines_running = False + # To route aborts to the correct engine. self.reqs_in_flight: dict[str, CoreEngine] = {} - super().__init__(vllm_config, executor_class, log_stats) + super().__init__(vllm_config, executor_class, log_stats, + client_addresses, client_index) assert len(self.core_engines) > 1 + self.lb_engines: Optional[list[int]] = None + self.lb_index = self.client_index + + self.first_req_sock_addr = get_open_zmq_inproc_path() + self.first_req_send_socket = make_zmq_socket(self.ctx, + self.first_req_sock_addr, + zmq.PAIR, + bind=True) + try: + # If we are running in an asyncio event loop, start the stats task. + # Otherwise, it will be started lazily. 
+ asyncio.get_running_loop() + self._ensure_stats_update_task() + except RuntimeError: + pass + + def _ensure_stats_update_task(self): + resources = self.resources + if resources.stats_update_task is not None: + return + + assert self.stats_update_address is not None + + async def run_engine_stats_update_task(): + with make_zmq_socket(self.ctx, self.stats_update_address, + zmq.XSUB) as socket, make_zmq_socket( + self.ctx, + self.first_req_sock_addr, + zmq.PAIR, + bind=False) as first_req_rcv_socket: + + # TODO CHECK WHY THIS SUB DOESN'T SEEM TO WORK + # Send subscription message. + await socket.send(b'\x01') + + poller = zmq.asyncio.Poller() + poller.register(socket, zmq.POLLIN) + poller.register(first_req_rcv_socket, zmq.POLLIN) + + while True: + events = await poller.poll() + if not self.engines_running and len(events) == 2 or ( + events[0][0] == first_req_rcv_socket): + # Send a message to notify the coordinator that + # we're sending a request while the engines are + # paused, so that it can wake the others up + # (to run dummy EP loop). + self.engines_running = True + buf = first_req_rcv_socket.recv( + flags=zmq.NOBLOCK).result() + target_eng_index = int.from_bytes(buf, "little") + msg = msgspec.msgpack.encode( + (target_eng_index, self.current_wave)) + await socket.send(msg) + + buf = None + while True: + # Drain all stats events (we only care about latest). + future: asyncio.Future[bytes] = socket.recv( + flags=zmq.NOBLOCK) + if isinstance(future.exception(), zmq.Again): + break + buf = future.result() + if buf is None: + continue + + # Update local load-balancing state. + engines, wave, running = msgspec.msgpack.decode(buf) + self.current_wave = wave + self.engines_running = running + if self.lb_engines != engines: + self.lb_index = self.client_index + self.lb_engines = engines + + resources.stats_update_task = asyncio.create_task( + run_engine_stats_update_task()) + + def get_core_engine_for_request(self) -> CoreEngine: + index = self.lb_index + if self.lb_engines: + eng_index = self.lb_engines[index % len(self.lb_engines)] + else: + eng_index = index % len(self.core_engines) + self.lb_index = index + 1 + return self.core_engines[eng_index] + async def call_utility_async(self, method: str, *args) -> Any: # Only the result from the first engine is returned. return (await asyncio.gather(*[ @@ -925,62 +972,30 @@ async def call_utility_async(self, method: str, *args) -> Any: ]))[0] async def add_request_async(self, request: EngineCoreRequest) -> None: + self._ensure_stats_update_task() + request.current_wave = self.current_wave + request.client_index = self.client_index chosen_engine = self.get_core_engine_for_request() self.reqs_in_flight[request.request_id] = chosen_engine - chosen_engine.num_reqs_in_flight += 1 to_await = self._send_input(EngineCoreRequestType.ADD, request, chosen_engine) if not self.engines_running: - # Send request to chosen engine and dp start loop - # control message to all other engines. 
- self.engines_running = True - to_await = asyncio.gather( - to_await, # type: ignore[assignment] - *self._start_wave_coros(exclude_index=chosen_engine.index)) + # Notify coordinator that we're sending a request + await self.first_req_send_socket.send(chosen_engine.identity) await to_await self._ensure_output_queue_task() - def get_core_engine_for_request(self) -> CoreEngine: - return min(self.core_engines, key=lambda e: e.num_reqs_in_flight) - @staticmethod async def process_engine_outputs(self: "DPAsyncMPClient", outputs: EngineCoreOutputs): - if self.reqs_in_flight: - for req_id in outputs.finished_requests or (): - if engine := self.reqs_in_flight.pop(req_id, None): - engine.num_reqs_in_flight -= 1 - - if outputs.wave_complete is not None: - # Current wave is complete, move to next wave number - # and mark engines as paused. - if self.current_wave <= outputs.wave_complete: - self.current_wave = outputs.wave_complete + 1 - self.engines_running = False - - elif outputs.start_wave is not None and ( - outputs.start_wave > self.current_wave or - (outputs.start_wave == self.current_wave - and not self.engines_running)): - # Engine received request for a non-current wave so we must ensure - # that other engines progress to the next wave. - self.current_wave = outputs.start_wave - self.engines_running = True - await asyncio.gather(*self._start_wave_coros( - exclude_index=outputs.engine_index)) - - def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]: - logger.debug("Sending start DP wave %d.", self.current_wave) - return [ - self._send_input(EngineCoreRequestType.START_DP_WAVE, - self.current_wave, engine) - for engine in self.core_engines if engine.index != exclude_index - ] + if outputs.finished_requests and self.reqs_in_flight: + for req_id in outputs.finished_requests: + self.reqs_in_flight.pop(req_id, None) async def abort_requests_async(self, request_ids: list[str]) -> None: if not request_ids: diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 9109bdcf42f..17575b123a1 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -35,7 +35,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): ... @abstractmethod - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): ... 
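
Since scheduler stats can now arrive on their own (published via the DP coordinator path) or be absent from a given call, implementations of record() have to tolerate None for either argument. A minimal conforming logger, sketched with stand-in types rather than the real vLLM classes:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeSchedulerStats:  # stand-in for vllm.v1.metrics.stats.SchedulerStats
    num_running_reqs: int = 0
    num_waiting_reqs: int = 0

class MinimalStatLogger:
    def record(self, scheduler_stats: Optional[FakeSchedulerStats],
               iteration_stats: Optional[object]) -> None:
        # Either argument may legitimately be None; guard each independently.
        if scheduler_stats is not None:
            print(f"running={scheduler_stats.num_running_reqs} "
                  f"waiting={scheduler_stats.num_waiting_reqs}")

log = MinimalStatLogger()
log.record(None, None)                      # no-op rather than AttributeError
log.record(FakeSchedulerStats(2, 1), None)  # prints running=2 waiting=1
```
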
@@ -78,20 +78,22 @@ def _get_throughput(self, tracked_stats: list[int], now: float) -> float: # Compute summary metrics for tracked stats return float(np.sum(tracked_stats) / (now - self.last_log_time)) - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): """Log Stats to standard output.""" if iteration_stats: self._track_iteration_stats(iteration_stats) - self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) + if scheduler_stats is not None: + self.prefix_caching_metrics.observe( + scheduler_stats.prefix_cache_stats) - if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_logging.observe( - scheduler_stats.spec_decoding_stats) + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_logging.observe( + scheduler_stats.spec_decoding_stats) - self.last_scheduler_stats = scheduler_stats + self.last_scheduler_stats = scheduler_stats def log(self): now = time.monotonic() @@ -373,22 +375,23 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): labelnames=metrics_info.keys()).labels(**metrics_info) info_gauge.set(1) - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): """Log to prometheus.""" - self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) - self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) + if scheduler_stats is not None: + self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) + self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) - self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) + self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) - self.counter_gpu_prefix_cache_queries.inc( - scheduler_stats.prefix_cache_stats.queries) - self.counter_gpu_prefix_cache_hits.inc( - scheduler_stats.prefix_cache_stats.hits) + self.counter_gpu_prefix_cache_queries.inc( + scheduler_stats.prefix_cache_stats.queries) + self.counter_gpu_prefix_cache_hits.inc( + scheduler_stats.prefix_cache_stats.hits) - if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_prom.observe( - scheduler_stats.spec_decoding_stats) + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_prom.observe( + scheduler_stats.spec_decoding_stats) if iteration_stats is None: return diff --git a/vllm/v1/request.py b/vllm/v1/request.py index fde366d61c7..52bc20567eb 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -20,18 +20,19 @@ class Request: def __init__( self, request_id: str, + client_index: int, prompt_token_ids: list[int], multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_hashes: Optional[list[str]], multi_modal_placeholders: Optional[list[PlaceholderRange]], sampling_params: SamplingParams, eos_token_id: Optional[int], - arrival_time: float, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, ) -> None: self.request_id = request_id + self.client_index = client_index self.sampling_params = sampling_params # Because of LoRA, the eos token id can be different for each request. 
self.eos_token_id = eos_token_id @@ -81,13 +82,13 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( request_id=request.request_id, + client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, multi_modal_inputs=request.mm_inputs, multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, sampling_params=request.sampling_params, eos_token_id=request.eos_token_id, - arrival_time=request.arrival_time, lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params), diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 0758747a83c..01d32a0c5db 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,22 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 -import os import time import weakref from collections import defaultdict from collections.abc import Sequence +from enum import Enum, auto from multiprocessing import Process, connection -from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union, - overload) +from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, + Union, overload) +import msgspec import torch +import zmq -from vllm.config import VllmConfig +from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import get_mp_context, kill_process_tree +from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path, + get_tcp_uri, kill_process_tree) from vllm.v1.executor.abstract import Executor if TYPE_CHECKING: @@ -26,6 +29,8 @@ T = TypeVar("T") +STARTUP_POLL_PERIOD_MS = 10000 + class ConstantList(Generic[T], Sequence): @@ -95,6 +100,13 @@ def __repr__(self): return f"ConstantList({self._x})" +def get_engine_client_zmq_addr(local_only: bool, + host: str, + port: int = 0) -> str: + return get_open_zmq_ipc_path() if local_only else (get_tcp_uri( + host, port or get_open_port())) + + class CoreEngineProcManager: """ Utility class to handle creation, readiness, and shutdown @@ -109,7 +121,7 @@ def __init__( local_start_index: int, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -117,7 +129,7 @@ def __init__( common_kwargs = { "vllm_config": vllm_config, "on_head_node": on_head_node, - "input_address": input_address, + "handshake_address": handshake_address, "executor_class": executor_class, "log_stats": log_stats, } @@ -135,8 +147,7 @@ def __init__( "local_dp_rank": local_index, })) - self._finalizer = weakref.finalize(self, shutdown, self.processes, - input_address) + self._finalizer = weakref.finalize(self, shutdown, self.processes) try: for proc in self.processes: proc.start() @@ -164,9 +175,125 @@ def finished_procs(self) -> dict[str, int]: } +class CoreEngineState(Enum): + NEW = auto() + CONNECTED = auto() + READY = auto() + + +class CoreEngine: + """One per data parallel rank.""" + + def __init__(self, index: int = 0, local: bool = True): + self.local = local + self.index = index + self.identity = index.to_bytes(2, "little") + + self.state = CoreEngineState.NEW + + +def wait_for_engine_startup( + handshake_socket: zmq.Socket, + addresses: dict[str, Any], + core_engines: list[CoreEngine], + parallel_config: ParallelConfig, + cache_config: CacheConfig, + proc_manager: 
Optional[CoreEngineProcManager], + coord_process: Optional[Process], +): + + # Wait for engine core process(es) to send ready messages. + local_count = parallel_config.data_parallel_size_local + remote_count = len(core_engines) - local_count + # [local, remote] counts + conn_pending, start_pending = [local_count, remote_count], [0, 0] + poller = zmq.Poller() + poller.register(handshake_socket, zmq.POLLIN) + + if proc_manager is not None: + for sentinel in proc_manager.sentinels(): + poller.register(sentinel, zmq.POLLIN) + if coord_process is not None: + poller.register(coord_process.sentinel, zmq.POLLIN) + while any(conn_pending) or any(start_pending): + events = poller.poll(STARTUP_POLL_PERIOD_MS) + if not events: + if any(conn_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to connect.", *conn_pending) + if any(start_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to start.", *start_pending) + continue + if len(events) > 1 or events[0][0] != handshake_socket: + # One of the local core processes exited. + finished = proc_manager.finished_procs() if proc_manager else {} + if coord_process is not None and coord_process.exitcode is not None: + finished[coord_process.name] = coord_process.exitcode + raise RuntimeError("Engine core initialization failed. " + "See root cause above. " + f"Failed core proc(s): {finished}") + + # Receive HELLO and READY messages from the input socket. + eng_identity, ready_msg_bytes = handshake_socket.recv_multipart() + eng_index = int.from_bytes(eng_identity, "little") + engine = next((e for e in core_engines if e.identity == eng_identity), + None) + if engine is None: + raise RuntimeError(f"Message from engine with unexpected data " + f"parallel rank: {eng_index}") + msg = msgspec.msgpack.decode(ready_msg_bytes) + status, local = msg["status"], msg["local"] + if local != engine.local: + raise RuntimeError(f"{status} message from " + f"{'local' if local else 'remote'} " + f"engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}") + + if status == "HELLO" and engine.state == CoreEngineState.NEW: + + # Send init message with DP config info. + init_message = msgspec.msgpack.encode({ + "addresses": addresses, + "parallel_config": { + "data_parallel_master_ip": + parallel_config.data_parallel_master_ip, + "data_parallel_master_port": + parallel_config.data_parallel_master_port, + "data_parallel_size": parallel_config.data_parallel_size, + }, + }) + handshake_socket.send_multipart((eng_identity, init_message), + copy=False) + conn_pending[0 if local else 1] -= 1 + start_pending[0 if local else 1] += 1 + engine.state = CoreEngineState.CONNECTED + elif status == "READY" and (engine.state == CoreEngineState.CONNECTED): + # Setup KV cache config with initialization state from + # engine core process. + + # TODO we'll receive one of these per engine in DP case. + # How should we aggregate? + # Also in multi-API server case, this runs in the bootstrap process + # and won't currently make its way into the published metrics. 
+            cache_config.num_gpu_blocks = msg["num_gpu_blocks"]
+
+            start_pending[0 if local else 1] -= 1
+            engine.state = CoreEngineState.READY
+        else:
+            raise RuntimeError(f"Unexpected {status} message for "
+                               f"{'local' if local else 'remote'} engine "
+                               f"{eng_index} in {engine.state} state.")
+
+        logger.debug("%s from %s core engine process %s.", status,
+                     "local" if local else "remote", eng_index)
+
+
 # Note(rob): shutdown function cannot be a bound method,
-# else the gc cannot collect the objedecoupct.
-def shutdown(procs: list[Process], input_address: str):
+# else the gc cannot collect the object.
+def shutdown(procs: list[Process]):
     # Shutdown the process.
     for proc in procs:
         if proc.is_alive():
@@ -185,12 +312,6 @@ def shutdown(procs: list[Process], input_address: str):
         if proc.is_alive() and (pid := proc.pid) is not None:
             kill_process_tree(pid)
 
-    # Remove zmq ipc socket files.
-    if input_address.startswith("ipc://"):
-        socket_file = input_address[len("ipc://"):]
-        if os and os.path.exists(socket_file):
-            os.remove(socket_file)
-
 
 def bind_kv_cache(
     kv_caches: dict[str, torch.Tensor],

From 742b53205826cf52609f38c1306e8a71aaed8b2e Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 1 May 2025 12:07:56 -0700
Subject: [PATCH 12/22] Fix engine init num_gpu_blocks logging

Avoids the exception, but still needs more work to be functional with
multiple API server procs.

Signed-off-by: Nick Hill
---
 vllm/v1/engine/async_llm.py | 5 +++--
 vllm/v1/metrics/loggers.py  | 9 +++++----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 31b844e3a68..0466c8702c0 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -124,8 +124,9 @@ def __init__(
             client_addresses=client_addresses,
             client_index=client_index,
         )
-        for stat_logger in self.stat_loggers[0]:
-            stat_logger.log_engine_initialized()
+        if self.stat_loggers:
+            for stat_logger in self.stat_loggers[0]:
+                stat_logger.log_engine_initialized()
         self.output_handler: Optional[asyncio.Task] = None
         try:
             # Start output handler eagerly if we are in the asyncio eventloop.
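
Background for the loggers.py hunk below: %d-style formatting raises TypeError while num_gpu_blocks is still unset, which is now possible when the startup handshake ran in a different process, so the log call has to be guarded. A small illustration (the value here is assumed):

```python
num_gpu_blocks = None  # e.g. not propagated to this API server process yet

try:
    "num_gpu_blocks is: %d" % num_gpu_blocks  # what the old log line did
except TypeError as e:
    print(f"unguarded formatting fails: {e}")

if num_gpu_blocks:  # the guarded form adopted below
    print("num_gpu_blocks is: %d" % num_gpu_blocks)
```
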
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 17575b123a1..f563810125a 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -135,10 +135,11 @@ def log(self): self.spec_decoding_logging.log(log_fn=log_fn) def log_engine_initialized(self): - logger.info( - "vllm cache_config_info with initialization " \ - "after num_gpu_blocks is: %d", - self.vllm_config.cache_config.num_gpu_blocks) + if self.vllm_config.cache_config.num_gpu_blocks: + logger.info( + "Engine %03d: vllm cache_config_info with initialization " + "after num_gpu_blocks is: %d", self.engine_index, + self.vllm_config.cache_config.num_gpu_blocks) class PrometheusStatLogger(StatLoggerBase): From 6340c87becd942646b56afe6f498a8e54fe91965 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 5 May 2025 14:26:18 -0700 Subject: [PATCH 13/22] Improve load balancing Signed-off-by: Nick Hill --- vllm/v1/engine/coordinator.py | 5 ++++- vllm/v1/engine/core_client.py | 37 +++++++++++++++++++++++------------ 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index d0421e00ade..ef5402196c6 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -130,7 +130,7 @@ def process_input_socket(self, front_publish_address: str, wait_for = 100 if self.stats_changed else 3000 events = poller.poll(timeout=max(0, wait_for - elapsed)) if not events: - engine_list = self._get_engine_list() + engine_list = self._get_engine_counts() to_publish = (engine_list, self.current_wave, self.engines_running) msg = msgspec.msgpack.encode(to_publish) @@ -195,6 +195,9 @@ def _send_start_wave(socket: zmq.Socket, wave: int, socket.send_multipart( (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) + def _get_engine_counts(self) -> list[list[int]]: + return [e.request_counts for e in self.engines] + def _get_engine_list(self) -> Optional[list[int]]: shortlist: list[int] = [] min_counts = [sys.maxsize, sys.maxsize] diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index d019279b56d..ba5614e8ffe 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -2,6 +2,7 @@ import asyncio import contextlib import queue +import sys import uuid import weakref from abc import ABC, abstractmethod @@ -878,8 +879,8 @@ def __init__(self, assert len(self.core_engines) > 1 - self.lb_engines: Optional[list[int]] = None - self.lb_index = self.client_index + # List of [waiting, running] pair per engine. + self.lb_engines: list[list[int]] = [] self.first_req_sock_addr = get_open_zmq_inproc_path() self.first_req_send_socket = make_zmq_socket(self.ctx, @@ -908,8 +909,6 @@ async def run_engine_stats_update_task(): self.first_req_sock_addr, zmq.PAIR, bind=False) as first_req_rcv_socket: - - # TODO CHECK WHY THIS SUB DOESN'T SEEM TO WORK # Send subscription message. await socket.send(b'\x01') @@ -945,23 +944,35 @@ async def run_engine_stats_update_task(): continue # Update local load-balancing state. 
-                    engines, wave, running = msgspec.msgpack.decode(buf)
+                    counts, wave, running = msgspec.msgpack.decode(buf)
                     self.current_wave = wave
                     self.engines_running = running
-                    if self.lb_engines != engines:
-                        self.lb_index = self.client_index
-                    self.lb_engines = engines
+                    self.lb_engines = counts
 
         resources.stats_update_task = asyncio.create_task(
             run_engine_stats_update_task())
 
     def get_core_engine_for_request(self) -> CoreEngine:
-        index = self.lb_index
-        if self.lb_engines:
-            eng_index = self.lb_engines[index % len(self.lb_engines)]
+        if not self.lb_engines:
+            return self.core_engines[0]
+        # TODO use P2C alg for larger DP sizes
+        num_engines = len(self.lb_engines)
+        min_counts = [sys.maxsize, sys.maxsize]
+        eng_index = 0
+        for i in range(num_engines):
+            # Start from client_index to help with balancing when
+            # engines are empty.
+            idx = (self.client_index + i) % num_engines
+            counts = self.lb_engines[idx]
+            if counts < min_counts:
+                min_counts = counts
+                eng_index = idx
+        # Adjust local counts for better balancing between stats updates
+        # from the coordinator (these are overwritten 10x per second).
+        if min_counts[0]:
+            min_counts[0] += 1
         else:
-            eng_index = index % len(self.core_engines)
-        self.lb_index = index + 1
+            min_counts[1] += 1
         return self.core_engines[eng_index]
 
     async def call_utility_async(self, method: str, *args) -> Any:

From 877f1959d5fefefcd6b14e6aa373eed8de8d4eb3 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 5 May 2025 17:10:25 -0700
Subject: [PATCH 14/22] small fixes

Signed-off-by: Nick Hill
---
 vllm/v1/engine/core.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index b8dbb1f261e..0a1f2d02fd8 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -392,11 +392,11 @@ def __init__(
         self.output_thread.start()
 
     @staticmethod
-    def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
+    def startup_handshake(handshake_socket: zmq.Socket, on_head_node: bool,
                           parallel_config: ParallelConfig) -> dict[str, Any]:
 
         # Send registration message.
-        input_socket.send(
+        handshake_socket.send(
             msgspec.msgpack.encode({
                 "status": "HELLO",
                 "local": on_head_node,
@@ -404,11 +404,11 @@ def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
 
         # Receive initialization message.
         logger.info("Waiting for init message from front-end.")
-        if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000):
+        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
             raise RuntimeError("Did not receive response from front-end "
                                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                                f"minutes")
-        init_bytes = input_socket.recv()
+        init_bytes = handshake_socket.recv()
         init_message = msgspec.msgpack.decode(init_bytes)
         logger.debug("Received init message: %s", init_message)
 
@@ -617,8 +617,8 @@ def process_input_sockets(self, input_addresses: list[str],
             # back to us.
input_socket.send(b'') poller.register(input_socket, zmq.POLLIN) - if coord_socket is not None: - poller.register(coord_socket, zmq.POLLIN) + if coord_socket is not None: + poller.register(coord_socket, zmq.POLLIN) while True: for input_socket, _ in poller.poll(): From 42c30bf4ba68bc511709d8529c0db4d7c419a1d5 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 12 May 2025 09:33:08 -0700 Subject: [PATCH 15/22] Fix test_startup_failure Signed-off-by: Nick Hill --- tests/v1/engine/test_engine_core_client.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index fd8d1fd7ff4..452fe1e37e2 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -18,9 +18,10 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore -from vllm.v1.engine.core_client import (AsyncMPClient, CoreEngine, - EngineCoreClient, SyncMPClient) +from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient, + SyncMPClient) from vllm.v1.executor.abstract import Executor +from vllm.v1.utils import CoreEngineProcManager from ...distributed.conftest import MockSubscriber from ...utils import create_new_process_for_each_test @@ -348,13 +349,13 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch): # Monkey-patch to extract core process pid while it's starting. core_proc_pid = [None] - ce_ctor = CoreEngine.__init__ + cepm_ctor = CoreEngineProcManager.__init__ - def patched_ce_ctor(self, *args, **kwargs): - ce_ctor(self, *args, **kwargs) - core_proc_pid[0] = self.proc_handle.proc.pid + def patched_cepm_ctor(self: CoreEngineProcManager, *args, **kwargs): + cepm_ctor(self, *args, **kwargs) + core_proc_pid[0] = self.processes[0].pid - m.setattr(CoreEngine, "__init__", patched_ce_ctor) + m.setattr(CoreEngineProcManager, "__init__", patched_cepm_ctor) t = time.time() engine_args = EngineArgs(model=MODEL_NAME) From 3904d10662f3dc1cdd301125743ac4061992dd15 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 12 May 2025 10:00:25 -0700 Subject: [PATCH 16/22] Fix mock config related test failure Signed-off-by: Nick Hill --- tests/async_engine/test_async_llm_engine.py | 2 +- vllm/config.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 48e2e31e5db..b6f44871497 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -41,7 +41,7 @@ def __init__(self): self.abort_request_calls = 0 self.request_id = None # Ugly, remove dependency when possible - self.parallel_config = ParallelConfig(1, 1, False) + self.parallel_config = ParallelConfig() self.model_config = MockModelConfig() async def step_async(self, virtual_engine): diff --git a/vllm/config.py b/vllm/config.py index d83232e2e1b..ff87ae9092f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1759,7 +1759,8 @@ def __post_init__(self) -> None: if self.data_parallel_size_local > self.data_parallel_size: raise ValueError( - "data_parallel_size_local must be <= data_parallel_size") + f"data_parallel_size_local ({self.data_parallel_size_local}) " + f"must be <= data_parallel_size ({self.data_parallel_size})") if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. 
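
The relationship between the two DP sizes validated above can be exercised in isolation; a trimmed sketch with only the two fields involved (not the full ParallelConfig class):

```python
from dataclasses import dataclass

@dataclass
class TinyParallelConfig:
    data_parallel_size: int = 1
    data_parallel_size_local: int = 1

    def __post_init__(self) -> None:
        if self.data_parallel_size_local > self.data_parallel_size:
            raise ValueError(
                f"data_parallel_size_local ({self.data_parallel_size_local}) "
                f"must be <= data_parallel_size ({self.data_parallel_size})")

TinyParallelConfig()  # defaults: fine
TinyParallelConfig(data_parallel_size=4,
                   data_parallel_size_local=0)  # headless node: fine
try:
    TinyParallelConfig(data_parallel_size=1, data_parallel_size_local=2)
except ValueError as e:
    print(e)  # local size may not exceed the global size
```
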
From 1bf3a6318c5f7e6033631213c9ee665b126335d7 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 16 May 2025 12:51:07 -0700 Subject: [PATCH 17/22] [Misc][DP] Fix AsyncLLM metrics for multi-API server deployments Signed-off-by: kouroshhakha Co-authored-by: kouroshhakha --- vllm/entrypoints/cli/serve.py | 5 +- vllm/entrypoints/openai/api_server.py | 22 +++----- vllm/v1/engine/async_llm.py | 3 ++ vllm/v1/metrics/loggers.py | 31 ++++++----- vllm/v1/metrics/prometheus.py | 74 +++++++++++++++++++++++++++ 5 files changed, 106 insertions(+), 29 deletions(-) create mode 100644 vllm/v1/metrics/prometheus.py diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 5c2f94e08ed..4fe1d3a780c 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -26,6 +26,7 @@ from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.core_client import CoreEngineProcManager from vllm.v1.executor.abstract import Executor +from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus from vllm.v1.utils import (CoreEngine, get_engine_client_zmq_addr, wait_for_engine_startup) @@ -162,7 +163,9 @@ def run_multi_api_server(args: argparse.Namespace): assert not args.headless num_api_servers = args.api_server_count - # assert num_api_servers > 1 + assert num_api_servers > 1 + + setup_multiprocess_prometheus() listen_address, sock = setup_server(args) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 97936048dc2..7261bdfa03c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -24,6 +24,8 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse +from prometheus_client import make_asgi_app +from prometheus_fastapi_instrumentator import Instrumentator from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import State from starlette.routing import Mount @@ -97,6 +99,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path, is_valid_ipv6_address, set_ulimit) +from vllm.v1.metrics.prometheus import get_prometheus_registry from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -323,22 +326,9 @@ async def validate_json_request(raw_request: Request): def mount_metrics(app: FastAPI): - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. 
- # See https://prometheus.github.io/client_python/multiprocess/ - from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app, - multiprocess) - from prometheus_fastapi_instrumentator import Instrumentator - - registry = REGISTRY - - prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) - if prometheus_multiproc_dir_path is not None: - logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", - prometheus_multiproc_dir_path) - registry = CollectorRegistry() - multiprocess.MultiProcessCollector(registry) + """Mount prometheus metrics to a FastAPI app.""" + + registry = get_prometheus_registry() Instrumentator( excluded_handlers=[ diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ab6200e6f18..7026793befb 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -34,6 +34,7 @@ from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, setup_default_loggers) +from vllm.v1.metrics.prometheus import shutdown_prometheus from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -198,6 +199,8 @@ def __del__(self): def shutdown(self): """Shutdown, cleaning up the background proc and IPC.""" + shutdown_prometheus() + if engine_core := getattr(self, "engine_core", None): engine_core.shutdown() diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 3bc6e291342..48e03f4e99d 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -12,13 +12,12 @@ from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason +from vllm.v1.metrics.prometheus import unregister_vllm_metrics from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) -_LOCAL_LOGGING_INTERVAL_SEC = 5.0 - StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] @@ -143,7 +142,8 @@ def log_engine_initialized(self): class PrometheusStatLogger(StatLoggerBase): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): - self._unregister_vllm_metrics() + + unregister_vllm_metrics() self.vllm_config = vllm_config self.engine_index = engine_index # Use this flag to hide metrics that were deprecated in @@ -168,11 +168,13 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.gauge_scheduler_running = prometheus_client.Gauge( name="vllm:num_requests_running", documentation="Number of requests in model execution batches.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) self.gauge_scheduler_waiting = prometheus_client.Gauge( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) # @@ -181,6 +183,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.gauge_gpu_cache_usage = prometheus_client.Gauge( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 
1 means 100 percent usage.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) self.counter_gpu_prefix_cache_queries = prometheus_client.Counter( @@ -241,6 +244,9 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames).labels(*labelvalues) + # TODO: This metric might be incorrect in case of using multiple + # api_server counts which uses prometheus mp. + # See: https://github.com/vllm-project/vllm/pull/18053 self.histogram_iteration_tokens = \ prometheus_client.Histogram( name="vllm:iteration_tokens_total", @@ -339,6 +345,9 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): # # LoRA metrics # + + # TODO: This metric might be incorrect in case of using multiple + # api_server counts which uses prometheus mp. self.gauge_lora_info: Optional[prometheus_client.Gauge] = None if vllm_config.lora_config is not None: self.labelname_max_lora = "max_lora" @@ -349,13 +358,16 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): prometheus_client.Gauge( name="vllm:lora_requests_info", documentation="Running stats on lora requests.", + multiprocess_mode="sum", labelnames=[ self.labelname_max_lora, self.labelname_waiting_lora_adapters, self.labelname_running_lora_adapters, - ]) + ], + ) def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): + metrics_info = config_obj.metrics_info() metrics_info["engine"] = self.engine_index @@ -371,7 +383,9 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): info_gauge = prometheus_client.Gauge( name=name, documentation=documentation, - labelnames=metrics_info.keys()).labels(**metrics_info) + multiprocess_mode="mostrecent", + labelnames=metrics_info.keys(), + ).labels(**metrics_info) info_gauge.set(1) def record(self, scheduler_stats: Optional[SchedulerStats], @@ -445,13 +459,6 @@ def record(self, scheduler_stats: Optional[SchedulerStats], self.gauge_lora_info.labels(**lora_info_labels)\ .set_to_current_time() - @staticmethod - def _unregister_vllm_metrics(): - # Unregister any existing vLLM collectors (for CI/CD - for collector in list(prometheus_client.REGISTRY._collector_to_names): - if hasattr(collector, "_name") and "vllm" in collector._name: - prometheus_client.REGISTRY.unregister(collector) - def log_engine_initialized(self): self.log_metrics_info("cache_config", self.vllm_config.cache_config) diff --git a/vllm/v1/metrics/prometheus.py b/vllm/v1/metrics/prometheus.py new file mode 100644 index 00000000000..c958d7cd31f --- /dev/null +++ b/vllm/v1/metrics/prometheus.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import tempfile +from typing import Optional + +from prometheus_client import REGISTRY, CollectorRegistry, multiprocess + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# Global temporary directory for prometheus multiprocessing +_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None + + +def setup_multiprocess_prometheus(): + """Set up prometheus multiprocessing directory if not already configured. + + """ + global _prometheus_multiproc_dir + + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + # Make TemporaryDirectory for prometheus multiprocessing + # Note: global TemporaryDirectory will be automatically + # cleaned up upon exit. 
+        _prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+        os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name
+        logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s",
+                     _prometheus_multiproc_dir.name)
+    else:
+        logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. "
+                       "This directory must be wiped between vLLM runs or "
+                       "you will find inaccurate metrics. Unset the variable "
+                       "and vLLM will properly handle cleanup.")
+
+
+def get_prometheus_registry():
+    """Get the appropriate prometheus registry based on multiprocessing
+    configuration.
+
+    Returns:
+        Registry: A prometheus registry
+    """
+    if os.getenv("PROMETHEUS_MULTIPROC_DIR") is not None:
+        logger.debug("Using multiprocess registry for prometheus metrics")
+        registry = CollectorRegistry()
+        multiprocess.MultiProcessCollector(registry)
+        return registry
+
+    return REGISTRY
+
+
+def unregister_vllm_metrics():
+    """Unregister any existing vLLM collectors from the prometheus registry.
+
+    This is useful for testing and CI/CD where metrics may be registered
+    multiple times across test runs.
+    """
+    registry = get_prometheus_registry()
+    # Unregister any existing vLLM collectors
+    for collector in list(registry._collector_to_names):
+        if hasattr(collector, "_name") and "vllm" in collector._name:
+            registry.unregister(collector)
+
+
+def shutdown_prometheus():
+    """Shutdown prometheus metrics."""
+    try:
+        pid = os.getpid()
+        multiprocess.mark_process_dead(pid)
+        logger.debug("Marked Prometheus metrics for process %d as dead", pid)
+    except Exception as e:
+        logger.error("Error during metrics cleanup: %s", str(e))

From 97a6e767553bce20572d78679262cd4f4357fdc3 Mon Sep 17 00:00:00 2001
From: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com>
Date: Sun, 18 May 2025 09:15:04 -0700
Subject: [PATCH 18/22] Graceful lifecycle management for API server and DP
 coordinator processes in case of api_server > 1

Signed-off-by: kouroshhakha
---
 .../test_api_server_process_manager.py        | 267 ++++++++++++++++++
 vllm/entrypoints/cli/serve.py                 |  54 ++--
 vllm/v1/engine/coordinator.py                 |   6 +-
 vllm/v1/metrics/prometheus.py                 |   5 +-
 vllm/v1/utils.py                              | 131 ++++++++-
 5 files changed, 431 insertions(+), 32 deletions(-)
 create mode 100644 tests/entrypoints/test_api_server_process_manager.py

diff --git a/tests/entrypoints/test_api_server_process_manager.py b/tests/entrypoints/test_api_server_process_manager.py
new file mode 100644
index 00000000000..4719bbd24e1
--- /dev/null
+++ b/tests/entrypoints/test_api_server_process_manager.py
@@ -0,0 +1,267 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import multiprocessing
+import socket
+import threading
+import time
+from unittest.mock import patch
+
+import pytest
+
+from vllm.v1.utils import (APIServerProcessManager,
+                           wait_for_completion_or_failure)
+
+# Global variables to control worker behavior
+WORKER_RUNTIME_SECONDS = 0.5
+
+
+# Mock implementation of run_api_server_worker
+def mock_run_api_server_worker(listen_address, sock, args, client_config=None):
+    """Mock run_api_server_worker that runs for a specific time."""
+    print(f"Mock worker started with client_config: {client_config}")
+    time.sleep(WORKER_RUNTIME_SECONDS)
+    print("Mock worker completed successfully")
+
+
+@pytest.fixture
+def api_server_args():
+    """Fixture to provide arguments for APIServerProcessManager."""
+    sock = socket.socket()
+    return {
+        "target_server_fn":
+        mock_run_api_server_worker,
+        "listen_address":
+        "localhost:8000",
+        "sock":
+        sock,
+        "args":
+        "test_args",  # Simple string to avoid 
pickling issues
+        "num_servers":
+        3,
+        "input_addresses": [
+            "tcp://127.0.0.1:5001", "tcp://127.0.0.1:5002",
+            "tcp://127.0.0.1:5003"
+        ],
+        "output_addresses": [
+            "tcp://127.0.0.1:6001", "tcp://127.0.0.1:6002",
+            "tcp://127.0.0.1:6003"
+        ],
+        "stats_update_address":
+        "tcp://127.0.0.1:7000",
+    }
+
+
+@pytest.mark.parametrize("with_stats_update", [True, False])
+def test_api_server_process_manager_init(api_server_args, with_stats_update):
+    """Test initializing the APIServerProcessManager."""
+    # Set the worker runtime to ensure tests complete in reasonable time
+    global WORKER_RUNTIME_SECONDS
+    WORKER_RUNTIME_SECONDS = 0.5
+
+    # Copy the args to avoid mutating the fixture
+    args = api_server_args.copy()
+
+    if not with_stats_update:
+        args.pop("stats_update_address")
+    manager = APIServerProcessManager(**args)
+
+    try:
+        # Verify the manager was initialized correctly
+        assert len(manager.processes) == 3
+
+        # Verify all processes are running
+        for proc in manager.processes:
+            assert proc.is_alive()
+
+        print("Waiting for processes to run...")
+        time.sleep(WORKER_RUNTIME_SECONDS / 2)
+
+        # They should still be alive at this point
+        for proc in manager.processes:
+            assert proc.is_alive()
+
+    finally:
+        # Always clean up the processes
+        print("Cleaning up processes...")
+        manager.close()
+
+        # Give processes time to terminate
+        time.sleep(0.2)
+
+        # Verify all processes were terminated
+        for proc in manager.processes:
+            assert not proc.is_alive()
+
+
+@patch("vllm.entrypoints.cli.serve.run_api_server_worker",
+       mock_run_api_server_worker)
+def test_wait_for_completion_or_failure(api_server_args):
+    """Test that wait_for_completion_or_failure works with failures."""
+    global WORKER_RUNTIME_SECONDS
+    WORKER_RUNTIME_SECONDS = 1.0
+
+    # Create the manager
+    manager = APIServerProcessManager(**api_server_args)
+
+    try:
+        assert len(manager.processes) == 3
+
+        # Create a result capture for the thread
+        result = {"exception": None}
+
+        def run_with_exception_capture():
+            try:
+                wait_for_completion_or_failure(api_server_manager=manager)
+            except Exception as e:
+                result["exception"] = e
+
+        # Start a thread to run wait_for_completion_or_failure
+        wait_thread = threading.Thread(target=run_with_exception_capture,
+                                       daemon=True)
+        wait_thread.start()
+
+        # Let all processes run for a short time
+        time.sleep(0.2)
+
+        # All processes should still be running
+        assert all(proc.is_alive() for proc in manager.processes)
+
+        # Now simulate a process failure
+        print("Simulating process failure...")
+        manager.processes[0].terminate()
+
+        # Wait for the wait_for_completion_or_failure
+        # to detect and handle the failure
+        # This should trigger it to terminate all other processes
+        wait_thread.join(timeout=1.0)
+
+        # The wait thread should have exited
+        assert not wait_thread.is_alive()
+
+        # Verify that an exception was raised with appropriate error message
+        assert result["exception"] is not None
+        assert "died with exit code" in str(result["exception"])
+
+        # All processes should now be terminated
+        for i, proc in enumerate(manager.processes):
+            assert not proc.is_alive(), f"Process {i} should not be alive"
+
+    finally:
+        manager.close()
+        time.sleep(0.2)
+
+
+@pytest.mark.timeout(30)
+def test_normal_completion(api_server_args):
+    """Test that wait_for_completion_or_failure works in normal completion."""
+    global WORKER_RUNTIME_SECONDS
+    WORKER_RUNTIME_SECONDS = 0.1
+
+    # Create the manager
+    manager = APIServerProcessManager(**api_server_args)
+
+    try:
+        # Give processes time to terminate
+        # wait for processes
to complete
+        remaining_processes = manager.processes.copy()
+        while remaining_processes:
+            remaining_processes = [
+                proc for proc in remaining_processes if proc.is_alive()
+            ]
+            time.sleep(0.1)
+
+        # Verify all processes have terminated
+        for i, proc in enumerate(manager.processes):
+            assert not proc.is_alive(
+            ), f"Process {i} still alive after completing"
+
+        # Now call wait_for_completion_or_failure
+        # since all processes have already
+        # terminated, it should return immediately
+        # with no error
+        wait_for_completion_or_failure(api_server_manager=manager)
+
+    finally:
+        # Clean up just in case
+        manager.close()
+        time.sleep(0.2)
+
+
+@pytest.mark.timeout(30)
+def test_external_process_monitoring(api_server_args):
+    """Test that wait_for_completion_or_failure handles additional processes."""
+    global WORKER_RUNTIME_SECONDS
+    WORKER_RUNTIME_SECONDS = 100
+
+    # Create and start the external process
+    # (simulates local_engine_manager or coordinator)
+    spawn_context = multiprocessing.get_context("spawn")
+    external_proc = spawn_context.Process(target=mock_run_api_server_worker,
+                                          name="MockExternalProcess")
+    external_proc.start()
+
+    # Create the class to simulate a coordinator
+    class MockCoordinator:
+
+        def __init__(self, proc):
+            self.proc = proc
+
+        def close(self):
+            if self.proc.is_alive():
+                self.proc.terminate()
+                self.proc.join(timeout=0.5)
+
+    # Create a mock coordinator with the external process
+    mock_coordinator = MockCoordinator(external_proc)
+
+    # Create the API server manager
+    manager = APIServerProcessManager(**api_server_args)
+
+    try:
+        # Verify manager initialization
+        assert len(manager.processes) == 3
+
+        # Create a result capture for the thread
+        result = {"exception": None}
+
+        def run_with_exception_capture():
+            try:
+                wait_for_completion_or_failure(api_server_manager=manager,
+                                               coordinator=mock_coordinator)
+            except Exception as e:
+                result["exception"] = e
+
+        # Start a thread to run wait_for_completion_or_failure
+        wait_thread = threading.Thread(target=run_with_exception_capture,
+                                       daemon=True)
+        wait_thread.start()
+
+        # Terminate the external process to trigger a failure
+        time.sleep(0.2)
+        external_proc.terminate()
+
+        # Wait for the thread to detect the failure
+        wait_thread.join(timeout=1.0)
+
+        # The wait thread should have completed
+        assert not wait_thread.is_alive(
+        ), "wait_for_completion_or_failure thread still running"
+
+        # Verify that an exception was raised with appropriate error message
+        assert result["exception"] is not None, "No exception was raised"
+        error_message = str(result["exception"])
+        assert "died with exit code" in error_message, \
+            f"Unexpected error message: {error_message}"
+        assert "MockExternalProcess" in error_message, \
+            f"Error doesn't mention external process: {error_message}"
+
+        # Verify that all API server processes were terminated as a result
+        for i, proc in enumerate(manager.processes):
+            assert not proc.is_alive(
+            ), f"API server process {i} was not terminated"
+
+    finally:
+        # Clean up
+        manager.close()
+        mock_coordinator.close()
+        time.sleep(0.2)
diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py
index 4fe1d3a780c..aa9f553cf4d 100644
--- a/vllm/entrypoints/cli/serve.py
+++ b/vllm/entrypoints/cli/serve.py
@@ -1,12 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import argparse
-import multiprocessing
 import os
 import signal
+import socket
 import sys
-from multiprocessing.context import SpawnProcess
-from typing import Any
+from multiprocessing import connection
+from
typing import Any, Union import uvloop import zmq @@ -27,11 +27,16 @@ from vllm.v1.engine.core_client import CoreEngineProcManager from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus -from vllm.v1.utils import (CoreEngine, get_engine_client_zmq_addr, +from vllm.v1.utils import (APIServerProcessManager, CoreEngine, + get_engine_client_zmq_addr, + wait_for_completion_or_failure, wait_for_engine_startup) logger = init_logger(__name__) +# Type for process sentinel objects +SentinelType = Union[connection.Connection, socket.socket, int] + class ServeSubcommand(CLISubcommand): """The `serve` subcommand for the vLLM CLI. """ @@ -204,6 +209,8 @@ def run_multi_api_server(args: argparse.Namespace): coordinator = DPCoordinator(parallel_config) addresses.update(coordinator.get_engine_socket_addresses()) stats_update_address = coordinator.get_stats_publish_address() + logger.info("Started DP Coordinator process (PID: %d)", + coordinator.proc.pid) handshake_address = get_engine_client_zmq_addr( local_only, host, parallel_config.data_parallel_rpc_port) @@ -226,33 +233,22 @@ def run_multi_api_server(args: argparse.Namespace): start_index=0, local_start_index=0) - # Start API servers. - spawn_context = multiprocessing.get_context("spawn") - api_server_workers: list[SpawnProcess] = [] - for i, in_addr, out_addr in zip(range(num_api_servers), - input_addresses, output_addresses): - client_config = { - "input_address": in_addr, - "output_address": out_addr, - "client_index": i - } - if stats_update_address is not None: - client_config["stats_update_address"] = stats_update_address - - # TODO check signal propagation - proc = spawn_context.Process(target=run_api_server_worker, - name=f"ApiServer_{i}", - args=(listen_address, sock, args, - client_config)) - api_server_workers.append(proc) - proc.start() + # Start API servers using the manager + api_server_manager = APIServerProcessManager( + target_server_fn=run_api_server_worker, + listen_address=listen_address, + sock=sock, + args=args, + num_servers=num_api_servers, + input_addresses=input_addresses, + output_addresses=output_addresses, + stats_update_address=stats_update_address) # Wait for engine handshakes to complete. 
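+    # One CoreEngine handle is created per DP rank; the first
+    # local_engine_count of them are marked as local to this node.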
core_engines = [ CoreEngine(index=i, local=(i < local_engine_count)) for i in range(dp_size) ] - wait_for_engine_startup( handshake_socket, addresses, @@ -263,9 +259,11 @@ def run_multi_api_server(args: argparse.Namespace): coordinator.proc if coordinator else None, ) - # TODO handle failures / clean shutdown here - for proc in api_server_workers: - proc.join() + # Wait for API servers + wait_for_completion_or_failure( + api_server_manager=api_server_manager, + local_engine_manager=local_engine_manager, + coordinator=coordinator) def run_api_server_worker(listen_address, diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index ef5402196c6..fa5b13ba23c 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -2,6 +2,7 @@ import multiprocessing import sys import time +import weakref from typing import Optional import msgspec.msgpack @@ -12,7 +13,7 @@ from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType from vllm.v1.serial_utils import MsgpackDecoder -from vllm.v1.utils import get_engine_client_zmq_addr +from vllm.v1.utils import get_engine_client_zmq_addr, shutdown logger = init_logger(__name__) @@ -48,6 +49,7 @@ def __init__(self, parallel_config: ParallelConfig): self.stats_publish_address = front_publish_address self.coord_in_address = back_publish_address self.coord_out_address = back_output_address + self._finalizer = weakref.finalize(self, shutdown, [self.proc]) def get_stats_publish_address(self) -> str: return self.stats_publish_address @@ -59,7 +61,7 @@ def get_engine_socket_addresses(self) -> dict[str, str]: } def close(self): - self.proc.terminate() + self._finalizer() class EngineState: diff --git a/vllm/v1/metrics/prometheus.py b/vllm/v1/metrics/prometheus.py index c958d7cd31f..f1256853536 100644 --- a/vllm/v1/metrics/prometheus.py +++ b/vllm/v1/metrics/prometheus.py @@ -56,8 +56,11 @@ def unregister_vllm_metrics(): This is useful for testing and CI/CD where metrics may be registered multiple times across test runs. + + Also, in case of multiprocess, we need to unregister the metrics from the + global registry. """ - registry = get_prometheus_registry() + registry = REGISTRY # Unregister any existing vLLM collectors for collector in list(registry._collector_to_names): if hasattr(collector, "_name") and "vllm" in collector._name: diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 81d08d3e963..0aa6da820a0 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 +import argparse +import multiprocessing import time import weakref from collections import defaultdict from collections.abc import Sequence from enum import Enum, auto from multiprocessing import Process, connection +from multiprocessing.context import SpawnProcess from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, - Union, overload) + Union, cast, overload) import msgspec import torch @@ -24,6 +27,7 @@ if TYPE_CHECKING: from vllm.attention.layer import Attention + from vllm.v1.engine.coordinator import DPCoordinator logger = init_logger(__name__) @@ -107,6 +111,74 @@ def get_engine_client_zmq_addr(local_only: bool, host, port or get_open_port())) +class APIServerProcessManager: + """Manages a group of API server processes. + + Handles creation, monitoring, and termination of API server worker + processes. Also monitors extra processes to check if they are healthy. 
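+
+    Note that close() (and garbage collection) terminates only the API
+    server processes; the extra processes are shut down by their owners.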
+ """ + + def __init__( + self, + target_server_fn: Callable, + listen_address: str, + sock: Any, + args: argparse.Namespace, + num_servers: int, + input_addresses: list[str], + output_addresses: list[str], + stats_update_address: Optional[str] = None, + ): + """Initialize and start API server worker processes. + + Args: + target_server_fn: Function to call for each API server process + listen_address: Address to listen for client connections + sock: Socket for client connections + args: Command line arguments + num_servers: Number of API server processes to start + input_addresses: Input addresses for each API server + output_addresses: Output addresses for each API server + stats_update_address: Optional stats update address + """ + self.listen_address = listen_address + self.sock = sock + self.args = args + + # Start API servers + spawn_context = multiprocessing.get_context("spawn") + self.processes: list[SpawnProcess] = [] + + for i, in_addr, out_addr in zip(range(num_servers), input_addresses, + output_addresses): + client_config = { + "input_address": in_addr, + "output_address": out_addr, + "client_index": i + } + if stats_update_address is not None: + client_config["stats_update_address"] = stats_update_address + + proc = spawn_context.Process(target=target_server_fn, + name=f"ApiServer_{i}", + args=(listen_address, sock, args, + client_config)) + self.processes.append(proc) + proc.start() + + logger.info("Started %d API server processes", len(self.processes)) + + # Casting due to mypy error + process_list: list[multiprocessing.Process] = cast( + list[multiprocessing.Process], self.processes) + # Shutdown only the API server processes on garbage collection + # The extra processes are managed by their owners + self._finalizer = weakref.finalize(self, shutdown, process_list) + + def close(self) -> None: + self._finalizer() + + class CoreEngineProcManager: """ Utility class to handle creation, readiness, and shutdown @@ -288,6 +360,63 @@ def wait_for_engine_startup( "local" if local else "remote", eng_index) +def wait_for_completion_or_failure( + api_server_manager: APIServerProcessManager, + local_engine_manager: Optional[CoreEngineProcManager] = None, + coordinator: Optional["DPCoordinator"] = None) -> None: + """Wait for all processes to complete or detect if any fail. + + Raises an exception if any process exits with a non-zero status. 
+ """ + + try: + logger.info("Waiting for API servers to complete ...") + # Create a mapping of sentinels to their corresponding processes + # for efficient lookup + sentinel_to_proc: dict[Any, Union[SpawnProcess, Process]] = { + proc.sentinel: proc + for proc in api_server_manager.processes + } + + if coordinator: + sentinel_to_proc.update( + {coordinator.proc.sentinel: coordinator.proc}) + + if local_engine_manager: + sentinel_to_proc.update({ + proc.sentinel: proc + for proc in local_engine_manager.processes + }) + + # Check if any process terminates + while sentinel_to_proc: + # Wait for any process to terminate + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc) + + # Process any terminated processes + for sentinel in ready_sentinels: + proc = sentinel_to_proc.pop(sentinel) + + # Check if process exited with error + if proc.exitcode != 0: + raise RuntimeError( + f"Process {proc.name} (PID: {proc.pid}) " + f"died with exit code {proc.exitcode}") + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, shutting down API servers...") + except Exception as e: + logger.exception("Exception occurred while running API servers: %s", + str(e)) + raise + finally: + logger.info("Terminating remaining processes ...") + api_server_manager.close() + if coordinator: + coordinator.close() + if local_engine_manager: + local_engine_manager.close() + + # Note(rob): shutdown function cannot be a bound method, # else the gc cannot collect the object. def shutdown(procs: list[Process]): From a0c835ee101a32fb9fe80b2359589eeb84bf65ca Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 18 May 2025 15:42:08 -0700 Subject: [PATCH 19/22] Disable MM cache for api_server_count > 1 Signed-off-by: Nick Hill --- vllm/entrypoints/cli/serve.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index aa9f553cf4d..ea3da787c29 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -168,15 +168,24 @@ def run_multi_api_server(args: argparse.Namespace): assert not args.headless num_api_servers = args.api_server_count - assert num_api_servers > 1 + #assert num_api_servers > 1 - setup_multiprocess_prometheus() + if num_api_servers > 1: + setup_multiprocess_prometheus() listen_address, sock = setup_server(args) engine_args = AsyncEngineArgs.from_cli_args(args) usage_context = UsageContext.OPENAI_API_SERVER vllm_config = engine_args.create_engine_config(usage_context=usage_context) + model_config = vllm_config.model_config + + if num_api_servers > 1 and model_config.is_multimodal_model and not ( + model_config.disable_mm_preprocessor_cache): + logger.warning("Multi-model preprocessor cache will be disabled for" + " api_server_count > 1") + model_config.disable_mm_preprocessor_cache = True + parallel_config = vllm_config.parallel_config assert parallel_config.data_parallel_rank == 0 From f7afac693ce239077cbd9946ecc87733f6e83482 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 18 May 2025 15:53:54 -0700 Subject: [PATCH 20/22] Fix typing in test Signed-off-by: Nick Hill --- tests/entrypoints/test_api_server_process_manager.py | 5 +++-- vllm/v1/utils.py | 9 +++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/entrypoints/test_api_server_process_manager.py b/tests/entrypoints/test_api_server_process_manager.py index 4719bbd24e1..0dd1fdd9969 100644 --- a/tests/entrypoints/test_api_server_process_manager.py +++ b/tests/entrypoints/test_api_server_process_manager.py 
@@ -4,6 +4,7 @@ import socket import threading import time +from typing import Optional from unittest.mock import patch import pytest @@ -107,7 +108,7 @@ def test_wait_for_completion_or_failure(api_server_args): assert len(manager.processes) == 3 # Create a result capture for the thread - result = {"exception": None} + result: dict[str, Optional[Exception]] = {"exception": None} def run_with_exception_capture(): try: @@ -222,7 +223,7 @@ def close(self): assert len(manager.processes) == 3 # Create a result capture for the thread - result = {"exception": None} + result: dict[str, Optional[Exception]] = {"exception": None} def run_with_exception_capture(): try: diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 0aa6da820a0..3e1855d23a3 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -379,14 +379,11 @@ def wait_for_completion_or_failure( } if coordinator: - sentinel_to_proc.update( - {coordinator.proc.sentinel: coordinator.proc}) + sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc if local_engine_manager: - sentinel_to_proc.update({ - proc.sentinel: proc - for proc in local_engine_manager.processes - }) + for proc in local_engine_manager.processes: + sentinel_to_proc[proc.sentinel] = proc # Check if any process terminates while sentinel_to_proc: From 34c5eb9c2f39ebaa4a7108abe93a5e90f6cd9b2b Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 19 May 2025 14:01:29 -0400 Subject: [PATCH 21/22] Fix Process typing Signed-off-by: Nick Hill --- vllm/v1/utils.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 3e1855d23a3..a5beff6c454 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -8,9 +8,9 @@ from collections.abc import Sequence from enum import Enum, auto from multiprocessing import Process, connection -from multiprocessing.context import SpawnProcess +from multiprocessing.process import BaseProcess from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, - Union, cast, overload) + Union, overload) import msgspec import torch @@ -147,7 +147,7 @@ def __init__( # Start API servers spawn_context = multiprocessing.get_context("spawn") - self.processes: list[SpawnProcess] = [] + self.processes: list[BaseProcess] = [] for i, in_addr, out_addr in zip(range(num_servers), input_addresses, output_addresses): @@ -168,12 +168,9 @@ def __init__( logger.info("Started %d API server processes", len(self.processes)) - # Casting due to mypy error - process_list: list[multiprocessing.Process] = cast( - list[multiprocessing.Process], self.processes) # Shutdown only the API server processes on garbage collection # The extra processes are managed by their owners - self._finalizer = weakref.finalize(self, shutdown, process_list) + self._finalizer = weakref.finalize(self, shutdown, self.processes) def close(self) -> None: self._finalizer() @@ -206,7 +203,7 @@ def __init__( "log_stats": log_stats, } - self.processes: list[Process] = [] + self.processes: list[BaseProcess] = [] for index in range(local_engine_count): local_index = local_start_index + index global_index = start_index + index @@ -373,7 +370,7 @@ def wait_for_completion_or_failure( logger.info("Waiting for API servers to complete ...") # Create a mapping of sentinels to their corresponding processes # for efficient lookup - sentinel_to_proc: dict[Any, Union[SpawnProcess, Process]] = { + sentinel_to_proc: dict[Any, BaseProcess] = { proc.sentinel: proc for proc in api_server_manager.processes } @@ -416,7 +413,7 @@ def 
wait_for_completion_or_failure(
 
 # Note(rob): shutdown function cannot be a bound method,
 # else the gc cannot collect the object.
-def shutdown(procs: list[Process]):
+def shutdown(procs: list[BaseProcess]):
     # Shutdown the process.
     for proc in procs:
         if proc.is_alive():

From c55cbb8c65fb99962038774f42eb9372e1ab8c61 Mon Sep 17 00:00:00 2001
From: Rui Qiao
Date: Mon, 12 May 2025 22:29:16 +0000
Subject: [PATCH 22/22] [V1] Support DP with Ray

Signed-off-by: Rui Qiao
---
 tests/v1/test_async_llm_dp.py |  16 ++-
 vllm/config.py                |   6 ++
 vllm/engine/arg_utils.py      |   9 ++
 vllm/entrypoints/cli/serve.py |  32 +++++-
 vllm/v1/engine/async_llm.py   |  15 ++-
 vllm/v1/engine/core.py        | 105 +++++++++++++++++++
 vllm/v1/engine/core_client.py | 187 +++++++++++++++++++++++++++++++++-
 vllm/v1/utils.py              | 181 ++++++++++++++++++++++++++++++++
 8 files changed, 536 insertions(+), 15 deletions(-)

diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py
index ce4c4d198db..49dc1cf8e06 100644
--- a/tests/v1/test_async_llm_dp.py
+++ b/tests/v1/test_async_llm_dp.py
@@ -19,8 +19,9 @@
     model="ibm-research/PowerMoE-3b",
     enforce_eager=True,
     disable_log_requests=True,
-    tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
+    tensor_parallel_size=int(os.getenv("TP_SIZE", 2)),
     data_parallel_size=int(os.getenv("DP_SIZE", 2)),
+    data_parallel_address=os.getenv("DP_ADDRESS", "127.0.0.1"),
 )
 
 if not current_platform.supports_v1(engine_args.create_model_config()):
@@ -59,14 +60,22 @@ async def generate(engine: AsyncLLM,
 
 
 @pytest.mark.parametrize(
-    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
+    "output_kind",
+    [
+        RequestOutputKind.DELTA,
+        RequestOutputKind.FINAL_ONLY,
+    ],
+)
+@pytest.mark.parametrize("data_parallel_backend", ["ray"])
 @pytest.mark.asyncio
-async def test_load(output_kind: RequestOutputKind):
+async def test_load(output_kind: RequestOutputKind,
+                    data_parallel_backend: str):
 
     with ExitStack() as after:
 
         prompt = "This is a test of data parallel"
 
+        engine_args.data_parallel_backend = data_parallel_backend
         engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
@@ -82,7 +91,6 @@ async def test_load(output_kind: RequestOutputKind):
                 asyncio.create_task(
                     generate(engine, request_id, prompt, output_kind,
                              NUM_EXPECTED_TOKENS)))
-
         # Confirm that we got all the EXPECTED tokens from the requests.
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) diff --git a/vllm/config.py b/vllm/config.py index a185a75c6bf..183a8ac3912 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1693,6 +1693,8 @@ class ParallelConfig: """Port for data parallel messaging.""" data_parallel_master_port: int = 29500 """Port of the data parallel master.""" + data_parallel_backend: str = "mp" + """Backend to use for data parallel, either "mp" or "ray".""" enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" max_parallel_loading_workers: Optional[int] = None @@ -1856,6 +1858,10 @@ def __post_init__(self) -> None: "please install Ray with `pip install " "ray`.") from ray_utils.ray_import_err backend = "ray" + elif self.data_parallel_backend == "ray": + logger.info("Using ray distributed inference because " + "data_parallel_backend is ray") + backend = "ray" elif ray_found: if self.placement_group: backend = "ray" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f0c6b15b79d..3c76bac396b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -290,6 +290,7 @@ class EngineArgs: data_parallel_size_local: Optional[int] = None data_parallel_address: Optional[str] = None data_parallel_rpc_port: Optional[int] = None + data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers @@ -618,6 +619,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, help='Port for data parallel RPC ' 'communication.') + parallel_group.add_argument('--data-parallel-backend', + '-dpb', + type=str, + help='Backend for data parallel, either ' + '"mp" or "ray".') parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) @@ -1058,6 +1064,8 @@ def create_engine_config( self.data_parallel_rpc_port is not None) else ParallelConfig.data_parallel_rpc_port + data_parallel_backend = self.data_parallel_backend + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, @@ -1065,6 +1073,7 @@ def create_engine_config( data_parallel_size_local=data_parallel_size_local, data_parallel_master_ip=data_parallel_address, data_parallel_rpc_port=data_parallel_rpc_port, + data_parallel_backend=data_parallel_backend, enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index ea3da787c29..f7362499fe1 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -28,9 +28,9 @@ from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus from vllm.v1.utils import (APIServerProcessManager, CoreEngine, - get_engine_client_zmq_addr, + CoreEngineActorManager, get_engine_client_zmq_addr, wait_for_completion_or_failure, - wait_for_engine_startup) + wait_for_engine_startup, wait_for_ray_engine_actors) logger = init_logger(__name__) @@ -221,6 +221,34 @@ def run_multi_api_server(args: argparse.Namespace): logger.info("Started DP Coordinator process (PID: %d)", coordinator.proc.pid) + if parallel_config.data_parallel_backend == "ray": + logger.info("Starting ray-based data 
parallel backend") + + engine_actor_manager = CoreEngineActorManager( + local_engine_count=local_engine_count, + start_index=args.data_parallel_start_rank, + local_start_index=0, + vllm_config=vllm_config, + addresses=addresses, + executor_class=Executor.get_class(vllm_config), + log_stats=not engine_args.disable_log_stats, + ) + # Start API servers using the manager + api_server_manager = APIServerProcessManager( + target_server_fn=run_api_server_worker, + listen_address=listen_address, + sock=sock, + args=args, + num_servers=num_api_servers, + input_addresses=input_addresses, + output_addresses=output_addresses, + stats_update_address=stats_update_address) + + wait_for_ray_engine_actors(api_server_manager=api_server_manager, + engine_actor_manager=engine_actor_manager, + coordinator=coordinator) + return + handshake_address = get_engine_client_zmq_addr( local_only, host, parallel_config.data_parallel_rpc_port) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 7026793befb..5fb00df80c7 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -25,7 +25,8 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import Device, cdiv from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient +from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient, + RayDPClient) from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm.v1.engine.output_processor import (OutputProcessor, RequestOutputCollector) @@ -114,9 +115,15 @@ def __init__( log_stats=self.log_stats) # EngineCore (starts the engine in background process). - core_client_class = AsyncMPClient if ( - vllm_config.parallel_config.data_parallel_size - == 1) else DPAsyncMPClient + core_client_class: Union[type[RayDPClient], type[DPAsyncMPClient], + type[AsyncMPClient]] + if vllm_config.parallel_config.data_parallel_size > 1: + if vllm_config.parallel_config.data_parallel_backend == "ray": + core_client_class = RayDPClient + else: + core_client_class = DPAsyncMPClient + else: + core_client_class = AsyncMPClient self.engine_core = core_client_class( vllm_config=vllm_config, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 3e35a910a9d..c44ccf8dfa4 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -866,3 +866,108 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) + + +class DPEngineCoreActor(DPEngineCoreProc): + """ + Ray Actor for running EngineCore in a data parallel context + """ + + def __init__( + self, + vllm_config: VllmConfig, + on_head_node: bool, + addresses, + executor_class: type[Executor], + log_stats: bool, + engine_index: int = 0, + dp_rank: int = 0, + local_dp_rank: int = 0, + ): + # TODO(rui): improve shutdown handling + + # Ensure we can serialize transformer config after spawning + maybe_register_config_serialize_by_value() + + parallel_config: ParallelConfig = vllm_config.parallel_config + assert parallel_config.data_parallel_size > 1 or dp_rank > 0 + # Set data parallel rank for this engine process. 
+ parallel_config.data_parallel_rank = dp_rank + parallel_config.data_parallel_rank_local = local_dp_rank + + input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() + + executor_fail_callback = lambda: input_queue.put_nowait( + (EngineCoreRequestType.EXECUTOR_FAILED, b'')) + + input_addresses: list[str] = addresses["input_addresses"] + output_addresses: list[str] = addresses["output_addresses"] + coord_in_addr: Optional[str] = addresses.get("coord_in_address") + coord_out_addr: Optional[str] = addresses.get("coord_out_address") + self.client_count = len(output_addresses) + self.coordinator = coord_out_addr is not None + + # Ray sets CUDA_VISIBLE_DEVICES to empty string, + # we clean this up to be able to properly initialize + # data parallel groups. + del os.environ['CUDA_VISIBLE_DEVICES'] + # Set up data parallel environment. + self._init_data_parallel(vllm_config) + + # Counts forward-passes of the model so that we can synchronize + # finished with DP peers every N steps. + self.counter = 0 + self.current_wave = 0 + + # Initialize engine core and model. + EngineCore.__init__(self, vllm_config, executor_class, log_stats, + executor_fail_callback) + + self.engine_index = engine_index + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) + self.engines_running = False + self.last_counts = (0, 0) + + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. + self.input_queue = input_queue + self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], + bytes]]() + identity = engine_index.to_bytes(length=2, byteorder="little") + threading.Thread(target=self.process_input_sockets, + args=(input_addresses, coord_in_addr, identity), + daemon=True).start() + self.output_thread = threading.Thread( + target=self.process_output_sockets, + args=(output_addresses, coord_out_addr, engine_index), + daemon=True) + self.output_thread.start() + + def wait_for_init(self): + """ + Wait until the engine core is initialized. + + This is just an empty method. When ray.get() on this method + (or any other method of the actor) returns, it is guaranteed + that actor creation (i.e., __init__) is complete. + """ + pass + + def run(self): + """ + Run the engine core busy loop. 
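+
+        Fatal errors are logged and re-raised; shutdown() runs on exit
+        in all cases.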
+ """ + try: + self.run_busy_loop() + except SystemExit: + logger.debug("EngineCore exiting.") + raise + except Exception: + logger.exception("EngineCore encountered a fatal error.") + raise + finally: + self.shutdown() diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 287cfe2b176..f52692ad478 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -29,8 +29,9 @@ from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr -from vllm.v1.utils import (CoreEngine, CoreEngineProcManager, - get_engine_client_zmq_addr, wait_for_engine_startup) +from vllm.v1.utils import (CoreEngine, CoreEngineActorManager, + CoreEngineProcManager, get_engine_client_zmq_addr, + wait_for_engine_startup) logger = init_logger(__name__) @@ -67,7 +68,11 @@ def make_client( if multiprocess_mode and asyncio_mode: if vllm_config.parallel_config.data_parallel_size > 1: - return DPAsyncMPClient(vllm_config, executor_class, log_stats) + if vllm_config.parallel_config.data_parallel_backend == "ray": + return RayDPClient(vllm_config, executor_class, log_stats) + else: + return DPAsyncMPClient(vllm_config, executor_class, + log_stats) return AsyncMPClient(vllm_config, executor_class, log_stats) @@ -271,7 +276,8 @@ class BackgroundResources: circular reference back to the client object.""" ctx: Union[zmq.Context] - local_engine_manager: Optional[CoreEngineProcManager] = None + local_engine_manager: Optional[Union[CoreEngineProcManager, + CoreEngineActorManager]] = None coordinator: Optional[DPCoordinator] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None @@ -480,13 +486,19 @@ def _wait_for_engine_startup(self, handshake_socket: zmq.Socket, if coordinator is not None: addresses.update(coordinator.get_engine_socket_addresses()) + proc_manager = self.resources.local_engine_manager + if proc_manager is not None: + assert isinstance(proc_manager, CoreEngineProcManager), ( + "_wait_for_engine_startup should only be " + "called with CoreEngineProcManager") + wait_for_engine_startup( handshake_socket, addresses, self.core_engines, self.vllm_config.parallel_config, self.vllm_config.cache_config, - self.resources.local_engine_manager, + proc_manager, coordinator.proc if coordinator else None, ) @@ -1045,3 +1057,168 @@ async def _abort_requests(self, request_ids: list[str], if not self.resources.engine_dead: await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) + + +class RayDPClient(DPAsyncMPClient): + """ + Ray-based client for multi-proc, multi-engine (data parallel) + EngineCore. + """ + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, + ): + self.client_index = client_index + self.current_wave = 0 + self.engines_running = False + self.reqs_in_flight: dict[str, CoreEngine] = {} + + self.vllm_config = vllm_config + # Serialization setup. + self.encoder = MsgpackEncoder() + self.decoder = MsgpackDecoder(EngineCoreOutputs) + + # ZMQ setup. + sync_ctx = zmq.Context(io_threads=2) + self.ctx = zmq.asyncio.Context(sync_ctx) + + # List of [waiting, running] pair per engine. 
+        self.lb_engines: list[list[int]] = []
+        self.first_req_sock_addr = get_open_zmq_inproc_path()
+        self.first_req_send_socket = make_zmq_socket(self.ctx,
+                                                     self.first_req_sock_addr,
+                                                     zmq.PAIR,
+                                                     bind=True)
+
+        # This will ensure resources created so far are closed
+        # when the client is garbage collected, even if an
+        # exception is raised mid-construction.
+        self.resources = BackgroundResources(ctx=sync_ctx)
+        self._finalizer = weakref.finalize(self, self.resources)
+        success = False
+        try:
+            parallel_config = vllm_config.parallel_config
+            local_engine_count = parallel_config.data_parallel_size_local
+            start_index = parallel_config.data_parallel_rank
+            local_start_index = parallel_config.data_parallel_rank_local
+
+            # SPMD mode is where there is an LLM instance per DP rank and
+            # one core engine per LLM, see
+            # examples/offline_inference/data_parallel.py.
+            spmd_mode = local_start_index is not None
+            if spmd_mode:
+                assert local_engine_count == 1
+                self.core_engines = [
+                    CoreEngine(index=local_start_index, local=True)
+                ]
+            else:
+                assert start_index == 0
+                local_start_index = 0
+                self.core_engines = [
+                    CoreEngine(index=i, local=(i < local_engine_count))
+                    for i in range(parallel_config.data_parallel_size)
+                ]
+
+            dp_size = parallel_config.data_parallel_size
+            local_only = spmd_mode or local_engine_count == dp_size
+
+            self.stats_update_address: Optional[str] = None
+            if client_addresses is not None:
+                input_address = client_addresses["input_address"]
+                output_address = client_addresses["output_address"]
+                self.stats_update_address = client_addresses.get(
+                    "stats_update_address")
+            else:
+                host = parallel_config.data_parallel_master_ip
+                input_address = get_engine_client_zmq_addr(local_only, host)
+                output_address = get_engine_client_zmq_addr(local_only, host)
+
+            # Create input and output sockets.
+            self.input_socket = self.resources.input_socket = make_zmq_socket(
+                self.ctx, input_address, zmq.ROUTER, bind=True)
+            self.resources.output_socket = make_zmq_socket(
+                self.ctx, output_address, zmq.PULL)
+
+            if client_addresses is None:
+                self._init_engines_direct(vllm_config, local_only,
+                                          local_start_index, input_address,
+                                          output_address, executor_class,
+                                          log_stats)
+                coordinator = self.resources.coordinator
+                if coordinator:
+                    self.stats_update_address = \
+                        coordinator.get_stats_publish_address()
+
+            # Wait for ready messages from each engine on the input socket.
+            identities = set(e.identity for e in self.core_engines)
+            sync_input_socket = zmq.Socket.shadow(self.input_socket)
+            while identities:
+                if not sync_input_socket.poll(timeout=600_000):
+                    raise TimeoutError("Timed out waiting for engines to send "
+                                       "initial message on input socket.")
+                identity, _ = sync_input_socket.recv_multipart()
+                identities.remove(identity)
+
+            self.core_engine = self.core_engines[0]
+
+            self.utility_results: dict[int, AnyFuture] = {}
+
+            # Request objects which may contain pytorch-allocated tensors
+            # that we need to keep references to until zmq is done with the
+            # underlying data.
+            self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]()
+            self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
+                                                     Exception]]()
+
+            success = True
+        finally:
+            if not success:
+                self._finalizer()
+
+        try:
+            # If we are running in an asyncio event loop, start the queue task.
+            # Otherwise, it will be started lazily. If it is not started here,
+            # we could miss EXECUTOR_FAILED messages from engine core if they
+            # occur prior to any requests being sent.
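+            # get_running_loop() raises RuntimeError when called outside
+            # of an event loop, which the except clause below swallows.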
+ asyncio.get_running_loop() + self._ensure_output_queue_task() + except RuntimeError: + pass + + def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, + local_start_index: int, input_address: str, + output_address: str, + executor_class: type[Executor], log_stats: bool): + """Self-contained client mode, launch engine and coordinator process + as needed.""" + + parallel_config = vllm_config.parallel_config + local_engine_count = parallel_config.data_parallel_size_local + start_index = parallel_config.data_parallel_rank + + if len(self.core_engines) > 1: + self.resources.coordinator = DPCoordinator(parallel_config) + + addresses: dict[str, Any] = { + "input_addresses": [input_address], + "output_addresses": [output_address], + } + + coordinator = self.resources.coordinator + if coordinator is not None: + addresses.update(coordinator.get_engine_socket_addresses()) + + # Start all engines. + self.resources.local_engine_manager = CoreEngineActorManager( + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + addresses=addresses, + local_engine_count=local_engine_count, + start_index=start_index, + local_start_index=local_start_index) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index a5beff6c454..1470dad5b2b 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -244,6 +244,136 @@ def finished_procs(self) -> dict[str, int]: } +class CoreEngineActorManager: + """ + Utility class to handle creation, readiness, and shutdown + of core engine Ray actors used by the AsyncLLM and LLMEngine. + + Different from CoreEngineProcManager, this class manages + core engines for both local and remote nodes. + """ + + def __init__( + self, + local_engine_count: int, + start_index: int, + local_start_index: int, + vllm_config: VllmConfig, + addresses, + executor_class: type[Executor], + log_stats: bool, + ): + import copy + + import ray + from ray._private.state import available_resources_per_node + from ray.util.scheduling_strategies import ( + PlacementGroupSchedulingStrategy) + from ray.util.state import list_nodes + + from vllm.v1.engine.core import DPEngineCoreActor + + self.local_engine_actors: list[ray.ActorHandle] = [] + self.remote_engine_actors: list[ray.ActorHandle] = [] + + dp_size = vllm_config.parallel_config.data_parallel_size + remote_engine_count = dp_size - local_engine_count + + if ray.is_initialized(): + logger.info( + "Ray is already initialized. 
Skipping Ray initialization.") + else: + ray.init() + + nodes = list_nodes() + available_resources_by_id = available_resources_per_node() + available_resources_by_ip = {} + num_workers = vllm_config.parallel_config.world_size + + dp_size_available = 0 + for node in nodes: + node_ip = node.node_ip + node_id = node.node_id + node_resources = available_resources_by_id[node_id] + available_resources_by_ip[node_ip] = node_resources + # For now, each DP rank can only be assigned to one node + # TODO(rui): support allocating a single DP rank to multiple nodes + dp_size_available += node_resources["GPU"] // num_workers + + assert dp_size_available >= dp_size, ( + "Not enough resources to allocate DP ranks") + + head_node_ip = \ + vllm_config.parallel_config.data_parallel_master_ip + + refs = [] + for index in range(local_engine_count): + local_index = local_start_index + index + global_index = start_index + index + dp_vllm_config = copy.deepcopy(vllm_config) + bundles = [{ + "GPU": 1.0, + "node:" + head_node_ip: 0.001 + }] * num_workers + [{ + "CPU": 1.0 + }] + pg = ray.util.placement_group( + name=f"dp_rank_{global_index}", + strategy="STRICT_PACK", + bundles=bundles, + ) + dp_vllm_config.parallel_config.placement_group = pg + actor = ray.remote(DPEngineCoreActor).options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=num_workers, + )).remote(vllm_config=dp_vllm_config, + executor_class=executor_class, + log_stats=log_stats, + addresses=addresses, + on_head_node=True, + engine_index=global_index, + dp_rank=global_index, + local_dp_rank=local_index) + self.local_engine_actors.append(actor) + refs.append(actor.wait_for_init.remote()) + + for index in range(remote_engine_count): + local_index = index + global_index = local_engine_count + index + bundles = [{"GPU": 1.0}] * num_workers + [{"CPU": 1.0}] + pg = ray.util.placement_group( + name=f"dp_rank_{global_index}", + strategy="STRICT_PACK", + bundles=bundles, + ) + dp_vllm_config = copy.deepcopy(vllm_config) + dp_vllm_config.parallel_config.placement_group = pg + actor = ray.remote(DPEngineCoreActor).options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=num_workers, + )).remote(vllm_config=dp_vllm_config, + executor_class=executor_class, + log_stats=log_stats, + addresses=addresses, + on_head_node=False, + engine_index=global_index, + dp_rank=global_index, + local_dp_rank=local_index) + self.remote_engine_actors.append(actor) + refs.append(actor.wait_for_init.remote()) + + ray.get(refs) + for actor in self.local_engine_actors + self.remote_engine_actors: + actor.run.remote() + + def close(self): + import ray + for actor in self.local_engine_actors + self.remote_engine_actors: + ray.kill(actor) + + class CoreEngineState(Enum): NEW = auto() CONNECTED = auto() @@ -411,6 +541,57 @@ def wait_for_completion_or_failure( local_engine_manager.close() +def wait_for_ray_engine_actors( + api_server_manager: APIServerProcessManager, + engine_actor_manager: CoreEngineActorManager, + coordinator: Optional["DPCoordinator"] = None) -> None: + """Wait for all ray engine actors to complete or detect if any fail. + + Raises an exception if any process exits with a non-zero status. 
+ """ + + try: + logger.info("Waiting for ray engine actors to complete ...") + # Create a mapping of sentinels to their corresponding processes + # for efficient lookup + sentinel_to_proc: dict[Any, Union[SpawnProcess, Process]] = { + proc.sentinel: proc + for proc in api_server_manager.processes + } + + if coordinator: + sentinel_to_proc.update( + {coordinator.proc.sentinel: coordinator.proc}) + + # TODO(rui): check if any ray engine actor terminates + # Check if any process terminates + while sentinel_to_proc: + # Wait for any process to terminate + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc) + + # Process any terminated processes + for sentinel in ready_sentinels: + proc = sentinel_to_proc.pop(sentinel) + + # Check if process exited with error + if proc.exitcode != 0: + raise RuntimeError( + f"Process {proc.name} (PID: {proc.pid}) " + f"died with exit code {proc.exitcode}") + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, shutting down API servers...") + except Exception as e: + logger.exception("Exception occurred while running API servers: %s", + str(e)) + raise + finally: + logger.info("Terminating remaining processes ...") + api_server_manager.close() + if coordinator: + coordinator.close() + engine_actor_manager.close() + + # Note(rob): shutdown function cannot be a bound method, # else the gc cannot collect the object. def shutdown(procs: list[BaseProcess]):