Skip to content

Commit d4a3dbe

Browse files
simon-mo and shreyankg
authored and committed
Revert "[V1] DP scale-out (1/N): Use zmq ROUTER/DEALER sockets for input queue (vllm-project#15906)"
This reverts commit 651cf0f.
1 parent ceb68de commit d4a3dbe

File tree

4 files changed

+69
-113
lines changed

4 files changed

+69
-113
lines changed

vllm/utils.py

Lines changed: 10 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -2189,8 +2189,6 @@ def make_zmq_socket(
21892189
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
21902190
path: str,
21912191
socket_type: Any,
2192-
bind: Optional[bool] = None,
2193-
identity: Optional[bytes] = None,
21942192
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
21952193
"""Make a ZMQ socket with the proper bind/connect semantics."""
21962194

@@ -2209,24 +2207,16 @@ def make_zmq_socket(
22092207
else:
22102208
buf_size = -1 # Use system default buffer size
22112209

2212-
if bind is None:
2213-
bind = socket_type != zmq.PUSH
2214-
2215-
if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
2216-
socket.setsockopt(zmq.RCVHWM, 0)
2217-
socket.setsockopt(zmq.RCVBUF, buf_size)
2218-
2219-
if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
2220-
socket.setsockopt(zmq.SNDHWM, 0)
2221-
socket.setsockopt(zmq.SNDBUF, buf_size)
2222-
2223-
if identity is not None:
2224-
socket.setsockopt(zmq.IDENTITY, identity)
2225-
2226-
if bind:
2210+
if socket_type == zmq.constants.PULL:
2211+
socket.setsockopt(zmq.constants.RCVHWM, 0)
2212+
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
22272213
socket.bind(path)
2228-
else:
2214+
elif socket_type == zmq.constants.PUSH:
2215+
socket.setsockopt(zmq.constants.SNDHWM, 0)
2216+
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
22292217
socket.connect(path)
2218+
else:
2219+
raise ValueError(f"Unknown Socket Type: {socket_type}")
22302220

22312221
return socket
22322222

@@ -2235,19 +2225,14 @@ def make_zmq_socket(
22352225
def zmq_socket_ctx(
22362226
path: str,
22372227
socket_type: Any,
2238-
bind: Optional[bool] = None,
22392228
linger: int = 0,
2240-
identity: Optional[bytes] = None,
22412229
) -> Iterator[zmq.Socket]:
22422230
"""Context manager for a ZMQ socket"""
22432231

22442232
ctx = zmq.Context() # type: ignore[attr-defined]
22452233
try:
2246-
yield make_zmq_socket(ctx,
2247-
path,
2248-
socket_type,
2249-
bind=bind,
2250-
identity=identity)
2234+
yield make_zmq_socket(ctx, path, socket_type)
2235+
22512236
except KeyboardInterrupt:
22522237
logger.debug("Got Keyboard Interrupt.")
22532238

vllm/v1/engine/core.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -318,11 +318,6 @@ def __init__(
318318
):
319319
super().__init__(vllm_config, executor_class, log_stats)
320320

321-
self.step_fn = (self.step if self.batch_queue is None else
322-
self.step_with_batch_queue)
323-
324-
self.global_unfinished_reqs = False
325-
326321
# Background Threads and Queues for IO. These enable us to
327322
# overlap ZMQ socket IO with GPU since they release the GIL,
328323
# and to overlap some serialization/deserialization with the
@@ -332,16 +327,22 @@ def __init__(
332327
Any]] = queue.Queue()
333328
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
334329
threading.Thread(target=self.process_input_socket,
335-
args=(input_path, engine_index),
330+
args=(input_path, ),
336331
daemon=True).start()
337332
threading.Thread(target=self.process_output_socket,
338333
args=(output_path, engine_index),
339334
daemon=True).start()
340335

336+
self.global_unfinished_reqs = False
337+
338+
self.step_fn = (self.step if self.batch_queue is None else
339+
self.step_with_batch_queue)
340+
341341
@staticmethod
342342
def run_engine_core(*args,
343343
dp_rank: int = 0,
344344
local_dp_rank: int = 0,
345+
ready_pipe,
345346
**kwargs):
346347
"""Launch EngineCore busy loop in background process."""
347348

@@ -376,6 +377,9 @@ def signal_handler(signum, frame):
376377
else:
377378
engine_core = EngineCoreProc(*args, **kwargs)
378379

380+
# Send Readiness signal to EngineClient.
381+
ready_pipe.send({"status": "READY"})
382+
379383
engine_core.run_busy_loop()
380384

381385
except SystemExit:
@@ -472,22 +476,14 @@ def _convert_msgspec_args(method, args):
472476
and not isinstance(v, p.annotation) else v
473477
for v, p in zip(args, arg_types))
474478

475-
def process_input_socket(self, input_path: str, engine_index: int):
479+
def process_input_socket(self, input_path: str):
476480
"""Input socket IO thread."""
477481

478482
# Msgpack serialization decoding.
479483
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
480484
generic_decoder = MsgpackDecoder()
481-
identity = engine_index.to_bytes(length=2, byteorder="little")
482-
483-
with zmq_socket_ctx(input_path,
484-
zmq.DEALER,
485-
identity=identity,
486-
bind=False) as socket:
487-
488-
# Send ready message to front-end once input socket is connected.
489-
socket.send(b'READY')
490485

486+
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
491487
while True:
492488
# (RequestType, RequestData)
493489
type_frame, data_frame = socket.recv_multipart(copy=False)

vllm/v1/engine/core_client.py

Lines changed: 37 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import uuid
99
import weakref
1010
from abc import ABC, abstractmethod
11-
from collections.abc import Awaitable
11+
from collections.abc import Awaitable, Sequence
1212
from concurrent.futures import Future
1313
from dataclasses import dataclass, field
1414
from threading import Thread
@@ -35,8 +35,6 @@
3535

3636
_R = TypeVar('_R') # Return type for collective_rpc
3737

38-
STARTUP_POLL_PERIOD_MS = 10000
39-
4038

4139
class EngineCoreClient(ABC):
4240
"""
@@ -263,13 +261,15 @@ def __init__(
263261
vllm_config: VllmConfig,
264262
executor_class: type[Executor],
265263
log_stats: bool,
266-
input_path: str,
264+
ctx: Union[zmq.Context, zmq.asyncio.Context],
267265
output_path: str,
268266
index: int = 0,
269267
local_dp_rank: int = 0,
270268
):
271-
self.index = index
272-
self.identity = index.to_bytes(length=2, byteorder="little")
269+
# Paths and sockets for IPC.
270+
input_path = get_open_zmq_ipc_path()
271+
self.input_socket = make_zmq_socket(ctx, input_path,
272+
zmq.constants.PUSH)
273273
try:
274274
# Start EngineCore in background process.
275275
self.proc_handle = BackgroundProcHandle(
@@ -291,9 +291,14 @@ def __init__(
291291
# Ensure socket is closed if process fails to start.
292292
self.close()
293293

294+
def send_multipart(self, msg_parts: Sequence):
295+
return self.input_socket.send_multipart(msg_parts, copy=False)
296+
294297
def close(self):
295298
if proc_handle := getattr(self, "proc_handle", None):
296299
proc_handle.shutdown()
300+
if socket := getattr(self, "input_socket", None):
301+
socket.close(linger=0)
297302

298303

299304
@dataclass
@@ -304,7 +309,6 @@ class BackgroundResources:
304309
ctx: Union[zmq.Context]
305310
core_engines: list[CoreEngine] = field(default_factory=list)
306311
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
307-
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
308312
shutdown_path: Optional[str] = None
309313

310314
def __call__(self):
@@ -317,8 +321,6 @@ def __call__(self):
317321
# aren't explicitly closed first.
318322
if self.output_socket is not None:
319323
self.output_socket.close(linger=0)
320-
if self.input_socket is not None:
321-
self.input_socket.close(linger=0)
322324
if self.shutdown_path is not None:
323325
# We must ensure that the sync output socket is
324326
# closed cleanly in its own thread.
@@ -385,51 +387,21 @@ def sigusr1_handler(signum, frame):
385387

386388
# Paths and sockets for IPC.
387389
self.output_path = get_open_zmq_ipc_path()
388-
input_path = get_open_zmq_ipc_path()
389-
self.input_socket = make_zmq_socket(self.ctx,
390-
input_path,
391-
zmq.ROUTER,
392-
bind=True)
393-
self.resources.input_socket = self.input_socket
394390

395391
new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
396-
vllm_config, executor_class, log_stats, input_path, self.
397-
output_path, index, local_dp_rank)
392+
vllm_config, executor_class, log_stats, self.ctx, self.output_path,
393+
index, local_dp_rank)
398394

399395
# Start engine core process(es).
400396
self._init_core_engines(vllm_config, new_core_engine,
401397
self.resources.core_engines)
402398

403399
# Wait for engine core process(es) to start.
404-
self._wait_for_engine_startup()
400+
for engine in self.resources.core_engines:
401+
engine.proc_handle.wait_for_startup()
405402

406403
self.utility_results: dict[int, AnyFuture] = {}
407404

408-
def _wait_for_engine_startup(self):
409-
# Get a sync handle to the socket which can be sync or async.
410-
sync_input_socket = zmq.Socket.shadow(self.input_socket)
411-
412-
# Wait for engine core process(es) to send ready messages.
413-
identities = set(eng.index for eng in self.resources.core_engines)
414-
while identities:
415-
while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS):
416-
logger.info("Waiting for %d core engine proc(s) to start: %s",
417-
len(identities), identities)
418-
eng_id_bytes, msg = sync_input_socket.recv_multipart()
419-
eng_id = int.from_bytes(eng_id_bytes, byteorder="little")
420-
if eng_id not in identities:
421-
raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}")
422-
if msg != b'READY':
423-
raise RuntimeError(f"Engine {eng_id} failed: {msg.decode()}")
424-
logger.info("Core engine process %d ready.", eng_id)
425-
identities.discard(eng_id)
426-
427-
# Double check that the process are running.
428-
for engine in self.resources.core_engines:
429-
proc = engine.proc_handle.proc
430-
if proc.exitcode is not None:
431-
raise RuntimeError(f"Engine proc {proc.name} not running")
432-
433405
def _init_core_engines(
434406
self,
435407
vllm_config: VllmConfig,
@@ -522,10 +494,9 @@ def get_output(self) -> EngineCoreOutputs:
522494
return self.outputs_queue.get()
523495

524496
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
525-
# (Identity, RequestType, SerializedRequest)
526-
msg = (self.core_engine.identity, request_type.value,
527-
self.encoder.encode(request))
528-
self.input_socket.send_multipart(msg, copy=False)
497+
# (RequestType, SerializedRequest)
498+
msg = (request_type.value, self.encoder.encode(request))
499+
self.core_engine.send_multipart(msg)
529500

530501
def call_utility(self, method: str, *args) -> Any:
531502
call_id = uuid.uuid1().int >> 64
@@ -654,34 +625,30 @@ async def get_output_async(self) -> EngineCoreOutputs:
654625
assert self.outputs_queue is not None
655626
return await self.outputs_queue.get()
656627

657-
def _send_input(self,
658-
request_type: EngineCoreRequestType,
659-
request: Any,
660-
engine: Optional[CoreEngine] = None) -> Awaitable[None]:
661-
if engine is None:
662-
engine = self.core_engine
663-
664-
message = (request_type.value, self.encoder.encode(request))
665-
return self._send_input_message(message, engine)
628+
async def _send_input(self, request_type: EngineCoreRequestType,
629+
request: Any) -> None:
630+
await self.core_engine.send_multipart(
631+
(request_type.value, self.encoder.encode(request)))
666632

667-
def _send_input_message(self, message: tuple[bytes, bytes],
668-
engine: CoreEngine) -> Awaitable[None]:
669-
message = (engine.identity, ) + message # type: ignore[assignment]
670-
return self.input_socket.send_multipart(message, copy=False)
633+
self._ensure_output_queue_task()
671634

672635
async def call_utility_async(self, method: str, *args) -> Any:
673636
return await self._call_utility_async(method,
674637
*args,
675638
engine=self.core_engine)
676639

677-
async def _call_utility_async(self, method: str, *args,
678-
engine: CoreEngine) -> Any:
640+
async def _call_utility_async(
641+
self,
642+
method: str,
643+
*args,
644+
engine: CoreEngine,
645+
) -> Any:
679646
call_id = uuid.uuid1().int >> 64
680647
future = asyncio.get_running_loop().create_future()
681648
self.utility_results[call_id] = future
682649
message = (EngineCoreRequestType.UTILITY.value,
683650
self.encoder.encode((call_id, method, args)))
684-
await self._send_input_message(message, engine)
651+
await engine.send_multipart(message)
685652
self._ensure_output_queue_task()
686653
return await future
687654

@@ -690,7 +657,6 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
690657
# tokenized.
691658
request.prompt = None
692659
await self._send_input(EngineCoreRequestType.ADD, request)
693-
self._ensure_output_queue_task()
694660

695661
async def abort_requests_async(self, request_ids: list[str]) -> None:
696662
if len(request_ids) > 0:
@@ -795,15 +761,15 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
795761
self.reqs_in_flight[request.request_id] = chosen_engine
796762
chosen_engine.num_reqs_in_flight += 1
797763
if self.num_engines_running >= len(self.core_engines):
798-
await self._send_input_message(msg, chosen_engine)
764+
await chosen_engine.send_multipart(msg)
799765
else:
800766
# Send request to chosen engine and dp start loop
801767
# control message to all other engines.
802768
self.num_engines_running += len(self.core_engines)
803769
await asyncio.gather(*[
804-
self._send_input_message(
805-
msg if engine is chosen_engine else self.start_dp_msg,
806-
engine) for engine in self.core_engines
770+
engine.send_multipart(msg if engine is
771+
chosen_engine else self.start_dp_msg)
772+
for engine in self.core_engines
807773
])
808774

809775
self._ensure_output_queue_task()
@@ -828,7 +794,7 @@ async def process_engine_outputs(self: "DPAsyncMPClient",
828794
# sure to start the other engines:
829795
self.num_engines_running = len(self.core_engines)
830796
coros = [
831-
self._send_input_message(self.start_dp_msg, engine)
797+
engine.send_multipart(self.start_dp_msg)
832798
for engine in self.core_engines
833799
if not engine.num_reqs_in_flight
834800
]
@@ -854,5 +820,5 @@ async def abort_requests_async(self, request_ids: list[str]) -> None:
854820

855821
async def _abort_requests(self, request_ids: list[str],
856822
engine: CoreEngine) -> None:
857-
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
858-
engine)
823+
await engine.send_multipart((EngineCoreRequestType.ABORT.value,
824+
self.encoder.encode(request_ids)))

vllm/v1/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,12 @@ def __init__(
105105
process_kwargs: dict[Any, Any],
106106
):
107107
context = get_mp_context()
108+
self.reader, writer = context.Pipe(duplex=False)
108109

109-
assert ("input_path" not in process_kwargs
110+
assert ("ready_pipe" not in process_kwargs
111+
and "input_path" not in process_kwargs
110112
and "output_path" not in process_kwargs)
113+
process_kwargs["ready_pipe"] = writer
111114
process_kwargs["input_path"] = input_path
112115
process_kwargs["output_path"] = output_path
113116

@@ -119,6 +122,12 @@ def __init__(
119122
input_path, output_path)
120123
self.proc.start()
121124

125+
def wait_for_startup(self):
126+
# Wait for startup.
127+
if self.reader.recv()["status"] != "READY":
128+
raise RuntimeError(f"{self.proc.name} initialization failed. "
129+
"See root cause above.")
130+
122131
def shutdown(self):
123132
self._finalizer()
124133

0 commit comments

Comments
 (0)