Skip to content

Commit bab5af3

Browse files
njhillnishith-fujitsu
authored andcommitted
[V1] DP scale-out (1/N): Use zmq ROUTER/DEALER sockets for input queue (vllm-project#15906)
Signed-off-by: Nick Hill <[email protected]>
1 parent 3c79adb commit bab5af3

File tree

4 files changed

+113
-69
lines changed

4 files changed

+113
-69
lines changed

vllm/utils.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2189,6 +2189,8 @@ def make_zmq_socket(
21892189
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
21902190
path: str,
21912191
socket_type: Any,
2192+
bind: Optional[bool] = None,
2193+
identity: Optional[bytes] = None,
21922194
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
21932195
"""Make a ZMQ socket with the proper bind/connect semantics."""
21942196

@@ -2207,16 +2209,24 @@ def make_zmq_socket(
22072209
else:
22082210
buf_size = -1 # Use system default buffer size
22092211

2210-
if socket_type == zmq.constants.PULL:
2211-
socket.setsockopt(zmq.constants.RCVHWM, 0)
2212-
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
2212+
if bind is None:
2213+
bind = socket_type != zmq.PUSH
2214+
2215+
if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
2216+
socket.setsockopt(zmq.RCVHWM, 0)
2217+
socket.setsockopt(zmq.RCVBUF, buf_size)
2218+
2219+
if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
2220+
socket.setsockopt(zmq.SNDHWM, 0)
2221+
socket.setsockopt(zmq.SNDBUF, buf_size)
2222+
2223+
if identity is not None:
2224+
socket.setsockopt(zmq.IDENTITY, identity)
2225+
2226+
if bind:
22132227
socket.bind(path)
2214-
elif socket_type == zmq.constants.PUSH:
2215-
socket.setsockopt(zmq.constants.SNDHWM, 0)
2216-
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
2217-
socket.connect(path)
22182228
else:
2219-
raise ValueError(f"Unknown Socket Type: {socket_type}")
2229+
socket.connect(path)
22202230

22212231
return socket
22222232

@@ -2225,14 +2235,19 @@ def make_zmq_socket(
22252235
def zmq_socket_ctx(
22262236
path: str,
22272237
socket_type: Any,
2238+
bind: Optional[bool] = None,
22282239
linger: int = 0,
2240+
identity: Optional[bytes] = None,
22292241
) -> Iterator[zmq.Socket]:
22302242
"""Context manager for a ZMQ socket"""
22312243

22322244
ctx = zmq.Context() # type: ignore[attr-defined]
22332245
try:
2234-
yield make_zmq_socket(ctx, path, socket_type)
2235-
2246+
yield make_zmq_socket(ctx,
2247+
path,
2248+
socket_type,
2249+
bind=bind,
2250+
identity=identity)
22362251
except KeyboardInterrupt:
22372252
logger.debug("Got Keyboard Interrupt.")
22382253

vllm/v1/engine/core.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,11 @@ def __init__(
318318
):
319319
super().__init__(vllm_config, executor_class, log_stats)
320320

321+
self.step_fn = (self.step if self.batch_queue is None else
322+
self.step_with_batch_queue)
323+
324+
self.global_unfinished_reqs = False
325+
321326
# Background Threads and Queues for IO. These enable us to
322327
# overlap ZMQ socket IO with GPU since they release the GIL,
323328
# and to overlap some serialization/deserialization with the
@@ -327,22 +332,16 @@ def __init__(
327332
Any]] = queue.Queue()
328333
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
329334
threading.Thread(target=self.process_input_socket,
330-
args=(input_path, ),
335+
args=(input_path, engine_index),
331336
daemon=True).start()
332337
threading.Thread(target=self.process_output_socket,
333338
args=(output_path, engine_index),
334339
daemon=True).start()
335340

336-
self.global_unfinished_reqs = False
337-
338-
self.step_fn = (self.step if self.batch_queue is None else
339-
self.step_with_batch_queue)
340-
341341
@staticmethod
342342
def run_engine_core(*args,
343343
dp_rank: int = 0,
344344
local_dp_rank: int = 0,
345-
ready_pipe,
346345
**kwargs):
347346
"""Launch EngineCore busy loop in background process."""
348347

@@ -377,9 +376,6 @@ def signal_handler(signum, frame):
377376
else:
378377
engine_core = EngineCoreProc(*args, **kwargs)
379378

380-
# Send Readiness signal to EngineClient.
381-
ready_pipe.send({"status": "READY"})
382-
383379
engine_core.run_busy_loop()
384380

385381
except SystemExit:
@@ -476,14 +472,22 @@ def _convert_msgspec_args(method, args):
476472
and not isinstance(v, p.annotation) else v
477473
for v, p in zip(args, arg_types))
478474

479-
def process_input_socket(self, input_path: str):
475+
def process_input_socket(self, input_path: str, engine_index: int):
480476
"""Input socket IO thread."""
481477

482478
# Msgpack serialization decoding.
483479
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
484480
generic_decoder = MsgpackDecoder()
481+
identity = engine_index.to_bytes(length=2, byteorder="little")
482+
483+
with zmq_socket_ctx(input_path,
484+
zmq.DEALER,
485+
identity=identity,
486+
bind=False) as socket:
487+
488+
# Send ready message to front-end once input socket is connected.
489+
socket.send(b'READY')
485490

486-
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
487491
while True:
488492
# (RequestType, RequestData)
489493
type_frame, data_frame = socket.recv_multipart(copy=False)

vllm/v1/engine/core_client.py

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import uuid
99
import weakref
1010
from abc import ABC, abstractmethod
11-
from collections.abc import Awaitable, Sequence
11+
from collections.abc import Awaitable
1212
from concurrent.futures import Future
1313
from dataclasses import dataclass, field
1414
from threading import Thread
@@ -35,6 +35,8 @@
3535

3636
_R = TypeVar('_R') # Return type for collective_rpc
3737

38+
STARTUP_POLL_PERIOD_MS = 10000
39+
3840

3941
class EngineCoreClient(ABC):
4042
"""
@@ -261,15 +263,13 @@ def __init__(
261263
vllm_config: VllmConfig,
262264
executor_class: type[Executor],
263265
log_stats: bool,
264-
ctx: Union[zmq.Context, zmq.asyncio.Context],
266+
input_path: str,
265267
output_path: str,
266268
index: int = 0,
267269
local_dp_rank: int = 0,
268270
):
269-
# Paths and sockets for IPC.
270-
input_path = get_open_zmq_ipc_path()
271-
self.input_socket = make_zmq_socket(ctx, input_path,
272-
zmq.constants.PUSH)
271+
self.index = index
272+
self.identity = index.to_bytes(length=2, byteorder="little")
273273
try:
274274
# Start EngineCore in background process.
275275
self.proc_handle = BackgroundProcHandle(
@@ -291,14 +291,9 @@ def __init__(
291291
# Ensure socket is closed if process fails to start.
292292
self.close()
293293

294-
def send_multipart(self, msg_parts: Sequence):
295-
return self.input_socket.send_multipart(msg_parts, copy=False)
296-
297294
def close(self):
298295
if proc_handle := getattr(self, "proc_handle", None):
299296
proc_handle.shutdown()
300-
if socket := getattr(self, "input_socket", None):
301-
socket.close(linger=0)
302297

303298

304299
@dataclass
@@ -309,6 +304,7 @@ class BackgroundResources:
309304
ctx: Union[zmq.Context]
310305
core_engines: list[CoreEngine] = field(default_factory=list)
311306
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
307+
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
312308
shutdown_path: Optional[str] = None
313309

314310
def __call__(self):
@@ -321,6 +317,8 @@ def __call__(self):
321317
# aren't explicitly closed first.
322318
if self.output_socket is not None:
323319
self.output_socket.close(linger=0)
320+
if self.input_socket is not None:
321+
self.input_socket.close(linger=0)
324322
if self.shutdown_path is not None:
325323
# We must ensure that the sync output socket is
326324
# closed cleanly in its own thread.
@@ -387,21 +385,51 @@ def sigusr1_handler(signum, frame):
387385

388386
# Paths and sockets for IPC.
389387
self.output_path = get_open_zmq_ipc_path()
388+
input_path = get_open_zmq_ipc_path()
389+
self.input_socket = make_zmq_socket(self.ctx,
390+
input_path,
391+
zmq.ROUTER,
392+
bind=True)
393+
self.resources.input_socket = self.input_socket
390394

391395
new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
392-
vllm_config, executor_class, log_stats, self.ctx, self.output_path,
393-
index, local_dp_rank)
396+
vllm_config, executor_class, log_stats, input_path, self.
397+
output_path, index, local_dp_rank)
394398

395399
# Start engine core process(es).
396400
self._init_core_engines(vllm_config, new_core_engine,
397401
self.resources.core_engines)
398402

399403
# Wait for engine core process(es) to start.
400-
for engine in self.resources.core_engines:
401-
engine.proc_handle.wait_for_startup()
404+
self._wait_for_engine_startup()
402405

403406
self.utility_results: dict[int, AnyFuture] = {}
404407

408+
def _wait_for_engine_startup(self):
409+
# Get a sync handle to the socket which can be sync or async.
410+
sync_input_socket = zmq.Socket.shadow(self.input_socket)
411+
412+
# Wait for engine core process(es) to send ready messages.
413+
identities = set(eng.index for eng in self.resources.core_engines)
414+
while identities:
415+
while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS):
416+
logger.info("Waiting for %d core engine proc(s) to start: %s",
417+
len(identities), identities)
418+
eng_id_bytes, msg = sync_input_socket.recv_multipart()
419+
eng_id = int.from_bytes(eng_id_bytes, byteorder="little")
420+
if eng_id not in identities:
421+
raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}")
422+
if msg != b'READY':
423+
raise RuntimeError(f"Engine {eng_id} failed: {msg.decode()}")
424+
logger.info("Core engine process %d ready.", eng_id)
425+
identities.discard(eng_id)
426+
427+
# Double check that the process are running.
428+
for engine in self.resources.core_engines:
429+
proc = engine.proc_handle.proc
430+
if proc.exitcode is not None:
431+
raise RuntimeError(f"Engine proc {proc.name} not running")
432+
405433
def _init_core_engines(
406434
self,
407435
vllm_config: VllmConfig,
@@ -494,9 +522,10 @@ def get_output(self) -> EngineCoreOutputs:
494522
return self.outputs_queue.get()
495523

496524
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
497-
# (RequestType, SerializedRequest)
498-
msg = (request_type.value, self.encoder.encode(request))
499-
self.core_engine.send_multipart(msg)
525+
# (Identity, RequestType, SerializedRequest)
526+
msg = (self.core_engine.identity, request_type.value,
527+
self.encoder.encode(request))
528+
self.input_socket.send_multipart(msg, copy=False)
500529

501530
def call_utility(self, method: str, *args) -> Any:
502531
call_id = uuid.uuid1().int >> 64
@@ -625,30 +654,34 @@ async def get_output_async(self) -> EngineCoreOutputs:
625654
assert self.outputs_queue is not None
626655
return await self.outputs_queue.get()
627656

628-
async def _send_input(self, request_type: EngineCoreRequestType,
629-
request: Any) -> None:
630-
await self.core_engine.send_multipart(
631-
(request_type.value, self.encoder.encode(request)))
657+
def _send_input(self,
658+
request_type: EngineCoreRequestType,
659+
request: Any,
660+
engine: Optional[CoreEngine] = None) -> Awaitable[None]:
661+
if engine is None:
662+
engine = self.core_engine
632663

633-
self._ensure_output_queue_task()
664+
message = (request_type.value, self.encoder.encode(request))
665+
return self._send_input_message(message, engine)
666+
667+
def _send_input_message(self, message: tuple[bytes, bytes],
668+
engine: CoreEngine) -> Awaitable[None]:
669+
message = (engine.identity, ) + message # type: ignore[assignment]
670+
return self.input_socket.send_multipart(message, copy=False)
634671

635672
async def call_utility_async(self, method: str, *args) -> Any:
636673
return await self._call_utility_async(method,
637674
*args,
638675
engine=self.core_engine)
639676

640-
async def _call_utility_async(
641-
self,
642-
method: str,
643-
*args,
644-
engine: CoreEngine,
645-
) -> Any:
677+
async def _call_utility_async(self, method: str, *args,
678+
engine: CoreEngine) -> Any:
646679
call_id = uuid.uuid1().int >> 64
647680
future = asyncio.get_running_loop().create_future()
648681
self.utility_results[call_id] = future
649682
message = (EngineCoreRequestType.UTILITY.value,
650683
self.encoder.encode((call_id, method, args)))
651-
await engine.send_multipart(message)
684+
await self._send_input_message(message, engine)
652685
self._ensure_output_queue_task()
653686
return await future
654687

@@ -657,6 +690,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
657690
# tokenized.
658691
request.prompt = None
659692
await self._send_input(EngineCoreRequestType.ADD, request)
693+
self._ensure_output_queue_task()
660694

661695
async def abort_requests_async(self, request_ids: list[str]) -> None:
662696
if len(request_ids) > 0:
@@ -761,15 +795,15 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
761795
self.reqs_in_flight[request.request_id] = chosen_engine
762796
chosen_engine.num_reqs_in_flight += 1
763797
if self.num_engines_running >= len(self.core_engines):
764-
await chosen_engine.send_multipart(msg)
798+
await self._send_input_message(msg, chosen_engine)
765799
else:
766800
# Send request to chosen engine and dp start loop
767801
# control message to all other engines.
768802
self.num_engines_running += len(self.core_engines)
769803
await asyncio.gather(*[
770-
engine.send_multipart(msg if engine is
771-
chosen_engine else self.start_dp_msg)
772-
for engine in self.core_engines
804+
self._send_input_message(
805+
msg if engine is chosen_engine else self.start_dp_msg,
806+
engine) for engine in self.core_engines
773807
])
774808

775809
self._ensure_output_queue_task()
@@ -794,7 +828,7 @@ async def process_engine_outputs(self: "DPAsyncMPClient",
794828
# sure to start the other engines:
795829
self.num_engines_running = len(self.core_engines)
796830
coros = [
797-
engine.send_multipart(self.start_dp_msg)
831+
self._send_input_message(self.start_dp_msg, engine)
798832
for engine in self.core_engines
799833
if not engine.num_reqs_in_flight
800834
]
@@ -820,5 +854,5 @@ async def abort_requests_async(self, request_ids: list[str]) -> None:
820854

821855
async def _abort_requests(self, request_ids: list[str],
822856
engine: CoreEngine) -> None:
823-
await engine.send_multipart((EngineCoreRequestType.ABORT.value,
824-
self.encoder.encode(request_ids)))
857+
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
858+
engine)

vllm/v1/utils.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,9 @@ def __init__(
105105
process_kwargs: dict[Any, Any],
106106
):
107107
context = get_mp_context()
108-
self.reader, writer = context.Pipe(duplex=False)
109108

110-
assert ("ready_pipe" not in process_kwargs
111-
and "input_path" not in process_kwargs
109+
assert ("input_path" not in process_kwargs
112110
and "output_path" not in process_kwargs)
113-
process_kwargs["ready_pipe"] = writer
114111
process_kwargs["input_path"] = input_path
115112
process_kwargs["output_path"] = output_path
116113

@@ -122,12 +119,6 @@ def __init__(
122119
input_path, output_path)
123120
self.proc.start()
124121

125-
def wait_for_startup(self):
126-
# Wait for startup.
127-
if self.reader.recv()["status"] != "READY":
128-
raise RuntimeError(f"{self.proc.name} initialization failed. "
129-
"See root cause above.")
130-
131122
def shutdown(self):
132123
self._finalizer()
133124

0 commit comments

Comments
 (0)