Skip to content

[V1] DP scale-out (1/N): Use zmq ROUTER/DEALER sockets for input queue #15906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2189,6 +2189,8 @@ def make_zmq_socket(
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
path: str,
socket_type: Any,
bind: Optional[bool] = None,
identity: Optional[bytes] = None,
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""

Expand All @@ -2207,16 +2209,24 @@ def make_zmq_socket(
else:
buf_size = -1 # Use system default buffer size

if socket_type == zmq.constants.PULL:
socket.setsockopt(zmq.constants.RCVHWM, 0)
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
if bind is None:
bind = socket_type != zmq.PUSH

if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
socket.setsockopt(zmq.RCVHWM, 0)
socket.setsockopt(zmq.RCVBUF, buf_size)

if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
socket.setsockopt(zmq.SNDHWM, 0)
socket.setsockopt(zmq.SNDBUF, buf_size)

if identity is not None:
socket.setsockopt(zmq.IDENTITY, identity)

if bind:
socket.bind(path)
elif socket_type == zmq.constants.PUSH:
socket.setsockopt(zmq.constants.SNDHWM, 0)
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
socket.connect(path)
else:
raise ValueError(f"Unknown Socket Type: {socket_type}")
socket.connect(path)

return socket

Expand All @@ -2225,14 +2235,19 @@ def make_zmq_socket(
def zmq_socket_ctx(
path: str,
socket_type: Any,
bind: Optional[bool] = None,
linger: int = 0,
identity: Optional[bytes] = None,
) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""

ctx = zmq.Context() # type: ignore[attr-defined]
try:
yield make_zmq_socket(ctx, path, socket_type)

yield make_zmq_socket(ctx,
path,
socket_type,
bind=bind,
identity=identity)
except KeyboardInterrupt:
logger.debug("Got Keyboard Interrupt.")

Expand Down
28 changes: 16 additions & 12 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,11 @@ def __init__(
):
super().__init__(vllm_config, executor_class, log_stats)

self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)

self.global_unfinished_reqs = False

# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
Expand All @@ -317,22 +322,16 @@ def __init__(
Any]] = queue.Queue()
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
threading.Thread(target=self.process_input_socket,
args=(input_path, ),
args=(input_path, engine_index),
daemon=True).start()
threading.Thread(target=self.process_output_socket,
args=(output_path, engine_index),
daemon=True).start()

self.global_unfinished_reqs = False

self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)

@staticmethod
def run_engine_core(*args,
dp_rank: int = 0,
local_dp_rank: int = 0,
ready_pipe,
**kwargs):
"""Launch EngineCore busy loop in background process."""

Expand Down Expand Up @@ -367,9 +366,6 @@ def signal_handler(signum, frame):
else:
engine_core = EngineCoreProc(*args, **kwargs)

# Send Readiness signal to EngineClient.
ready_pipe.send({"status": "READY"})

engine_core.run_busy_loop()

except SystemExit:
Expand Down Expand Up @@ -466,14 +462,22 @@ def _convert_msgspec_args(method, args):
and not isinstance(v, p.annotation) else v
for v, p in zip(args, arg_types))

def process_input_socket(self, input_path: str):
def process_input_socket(self, input_path: str, engine_index: int):
"""Input socket IO thread."""

# Msgpack serialization decoding.
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
generic_decoder = MsgpackDecoder()
identity = engine_index.to_bytes(length=2, byteorder="little")

with zmq_socket_ctx(input_path,
zmq.DEALER,
identity=identity,
bind=False) as socket:

# Send ready message to front-end once input socket is connected.
socket.send(b'READY')

with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
Expand Down
108 changes: 71 additions & 37 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import uuid
import weakref
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Sequence
from collections.abc import Awaitable
from concurrent.futures import Future
from dataclasses import dataclass, field
from threading import Thread
Expand All @@ -35,6 +35,8 @@

_R = TypeVar('_R') # Return type for collective_rpc

STARTUP_POLL_PERIOD_MS = 10000


class EngineCoreClient(ABC):
"""
Expand Down Expand Up @@ -243,15 +245,13 @@ def __init__(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
ctx: Union[zmq.Context, zmq.asyncio.Context],
input_path: str,
output_path: str,
index: int = 0,
local_dp_rank: int = 0,
):
# Paths and sockets for IPC.
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(ctx, input_path,
zmq.constants.PUSH)
self.index = index
self.identity = index.to_bytes(length=2, byteorder="little")
try:
# Start EngineCore in background process.
self.proc_handle = BackgroundProcHandle(
Expand All @@ -273,14 +273,9 @@ def __init__(
# Ensure socket is closed if process fails to start.
self.close()

def send_multipart(self, msg_parts: Sequence):
return self.input_socket.send_multipart(msg_parts, copy=False)

def close(self):
if proc_handle := getattr(self, "proc_handle", None):
proc_handle.shutdown()
if socket := getattr(self, "input_socket", None):
socket.close(linger=0)


@dataclass
Expand All @@ -291,6 +286,7 @@ class BackgroundResources:
ctx: Union[zmq.Context]
core_engines: list[CoreEngine] = field(default_factory=list)
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
shutdown_path: Optional[str] = None

def __call__(self):
Expand All @@ -303,6 +299,8 @@ def __call__(self):
# aren't explicitly closed first.
if self.output_socket is not None:
self.output_socket.close(linger=0)
if self.input_socket is not None:
self.input_socket.close(linger=0)
if self.shutdown_path is not None:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
Expand Down Expand Up @@ -369,21 +367,51 @@ def sigusr1_handler(signum, frame):

# Paths and sockets for IPC.
self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(self.ctx,
input_path,
zmq.ROUTER,
bind=True)
self.resources.input_socket = self.input_socket

new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
vllm_config, executor_class, log_stats, self.ctx, self.output_path,
index, local_dp_rank)
vllm_config, executor_class, log_stats, input_path, self.
output_path, index, local_dp_rank)

# Start engine core process(es).
self._init_core_engines(vllm_config, new_core_engine,
self.resources.core_engines)

# Wait for engine core process(es) to start.
for engine in self.resources.core_engines:
engine.proc_handle.wait_for_startup()
self._wait_for_engine_startup()

self.utility_results: dict[int, AnyFuture] = {}

def _wait_for_engine_startup(self):
# Get a sync handle to the socket which can be sync or async.
sync_input_socket = zmq.Socket.shadow(self.input_socket)

# Wait for engine core process(es) to send ready messages.
identities = set(eng.index for eng in self.resources.core_engines)
while identities:
while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS):
logger.info("Waiting for %d core engine proc(s) to start: %s",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if this runs forever and never succeeds?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hongxiayang good point... the design here is for accommodating remote engines which might be started at different times, but we should still exit when one of the local engines fails. I'll make an update to address this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hongxiayang FYI here's the fix #16137

len(identities), identities)
eng_id_bytes, msg = sync_input_socket.recv_multipart()
eng_id = int.from_bytes(eng_id_bytes, byteorder="little")
if eng_id not in identities:
raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}")
if msg != b'READY':
raise RuntimeError(f"Engine {eng_id} failed: {msg.decode()}")
logger.info("Core engine process %d ready.", eng_id)
identities.discard(eng_id)

# Double-check that the processes are running.
for engine in self.resources.core_engines:
proc = engine.proc_handle.proc
if proc.exitcode is not None:
raise RuntimeError(f"Engine proc {proc.name} not running")

def _init_core_engines(
self,
vllm_config: VllmConfig,
Expand Down Expand Up @@ -476,9 +504,10 @@ def get_output(self) -> EngineCoreOutputs:
return self.outputs_queue.get()

def _send_input(self, request_type: EngineCoreRequestType, request: Any):
# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
self.core_engine.send_multipart(msg)
# (Identity, RequestType, SerializedRequest)
msg = (self.core_engine.identity, request_type.value,
self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False)

def call_utility(self, method: str, *args) -> Any:
call_id = uuid.uuid1().int >> 64
Expand Down Expand Up @@ -601,30 +630,34 @@ async def get_output_async(self) -> EngineCoreOutputs:
assert self.outputs_queue is not None
return await self.outputs_queue.get()

async def _send_input(self, request_type: EngineCoreRequestType,
request: Any) -> None:
await self.core_engine.send_multipart(
(request_type.value, self.encoder.encode(request)))
def _send_input(self,
request_type: EngineCoreRequestType,
request: Any,
engine: Optional[CoreEngine] = None) -> Awaitable[None]:
if engine is None:
engine = self.core_engine

self._ensure_output_queue_task()
message = (request_type.value, self.encoder.encode(request))
return self._send_input_message(message, engine)

def _send_input_message(self, message: tuple[bytes, bytes],
engine: CoreEngine) -> Awaitable[None]:
message = (engine.identity, ) + message # type: ignore[assignment]
return self.input_socket.send_multipart(message, copy=False)

async def call_utility_async(self, method: str, *args) -> Any:
return await self._call_utility_async(method,
*args,
engine=self.core_engine)

async def _call_utility_async(
self,
method: str,
*args,
engine: CoreEngine,
) -> Any:
async def _call_utility_async(self, method: str, *args,
engine: CoreEngine) -> Any:
call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future
message = (EngineCoreRequestType.UTILITY.value,
self.encoder.encode((call_id, method, args)))
await engine.send_multipart(message)
await self._send_input_message(message, engine)
self._ensure_output_queue_task()
return await future

Expand All @@ -633,6 +666,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
# tokenized.
request.prompt = None
await self._send_input(EngineCoreRequestType.ADD, request)
self._ensure_output_queue_task()

async def abort_requests_async(self, request_ids: list[str]) -> None:
if len(request_ids) > 0:
Expand Down Expand Up @@ -730,15 +764,15 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
self.reqs_in_flight[request.request_id] = chosen_engine
chosen_engine.num_reqs_in_flight += 1
if self.num_engines_running >= len(self.core_engines):
await chosen_engine.send_multipart(msg)
await self._send_input_message(msg, chosen_engine)
else:
# Send request to chosen engine and dp start loop
# control message to all other engines.
self.num_engines_running += len(self.core_engines)
await asyncio.gather(*[
engine.send_multipart(msg if engine is
chosen_engine else self.start_dp_msg)
for engine in self.core_engines
self._send_input_message(
msg if engine is chosen_engine else self.start_dp_msg,
engine) for engine in self.core_engines
])

self._ensure_output_queue_task()
Expand All @@ -763,7 +797,7 @@ async def process_engine_outputs(self: "DPAsyncMPClient",
# sure to start the other engines:
self.num_engines_running = len(self.core_engines)
coros = [
engine.send_multipart(self.start_dp_msg)
self._send_input_message(self.start_dp_msg, engine)
for engine in self.core_engines
if not engine.num_reqs_in_flight
]
Expand All @@ -789,5 +823,5 @@ async def abort_requests_async(self, request_ids: list[str]) -> None:

async def _abort_requests(self, request_ids: list[str],
engine: CoreEngine) -> None:
await engine.send_multipart((EngineCoreRequestType.ABORT.value,
self.encoder.encode(request_ids)))
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
engine)
11 changes: 1 addition & 10 deletions vllm/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,9 @@ def __init__(
process_kwargs: dict[Any, Any],
):
context = get_mp_context()
self.reader, writer = context.Pipe(duplex=False)

assert ("ready_pipe" not in process_kwargs
and "input_path" not in process_kwargs
assert ("input_path" not in process_kwargs
and "output_path" not in process_kwargs)
process_kwargs["ready_pipe"] = writer
process_kwargs["input_path"] = input_path
process_kwargs["output_path"] = output_path

Expand All @@ -122,12 +119,6 @@ def __init__(
input_path, output_path)
self.proc.start()

def wait_for_startup(self):
# Wait for startup.
if self.reader.recv()["status"] != "READY":
raise RuntimeError(f"{self.proc.name} initialization failed. "
"See root cause above.")

def shutdown(self):
self._finalizer()

Expand Down