Skip to content

Commit b75080f

Browse files
committed
[V1] DP scale-out (1/N): Use zmq ROUTER/DEALER socket types for input queue
So that a single queue / socket address can be used for all engines. Signed-off-by: Nick Hill <[email protected]>
1 parent e6e3c55 commit b75080f

File tree

3 files changed

+74
-48
lines changed

3 files changed

+74
-48
lines changed

vllm/utils.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2171,6 +2171,8 @@ def make_zmq_socket(
21712171
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
21722172
path: str,
21732173
socket_type: Any,
2174+
bind: Optional[bool] = None,
2175+
identity: Optional[bytes] = None,
21742176
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
21752177
"""Make a ZMQ socket with the proper bind/connect semantics."""
21762178

@@ -2189,16 +2191,24 @@ def make_zmq_socket(
21892191
else:
21902192
buf_size = -1 # Use system default buffer size
21912193

2192-
if socket_type == zmq.constants.PULL:
2193-
socket.setsockopt(zmq.constants.RCVHWM, 0)
2194-
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
2194+
if bind is None:
2195+
bind = socket_type != zmq.PUSH
2196+
2197+
if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
2198+
socket.setsockopt(zmq.RCVHWM, 0)
2199+
socket.setsockopt(zmq.RCVBUF, buf_size)
2200+
2201+
if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
2202+
socket.setsockopt(zmq.SNDHWM, 0)
2203+
socket.setsockopt(zmq.SNDBUF, buf_size)
2204+
2205+
if identity is not None:
2206+
socket.setsockopt(zmq.IDENTITY, identity)
2207+
2208+
if bind:
21952209
socket.bind(path)
2196-
elif socket_type == zmq.constants.PUSH:
2197-
socket.setsockopt(zmq.constants.SNDHWM, 0)
2198-
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
2199-
socket.connect(path)
22002210
else:
2201-
raise ValueError(f"Unknown Socket Type: {socket_type}")
2211+
socket.connect(path)
22022212

22032213
return socket
22042214

@@ -2207,14 +2217,19 @@ def make_zmq_socket(
22072217
def zmq_socket_ctx(
22082218
path: str,
22092219
socket_type: Any,
2220+
bind: Optional[bool] = None,
22102221
linger: int = 0,
2222+
identity: Optional[bytes] = None,
22112223
) -> Iterator[zmq.Socket]:
22122224
"""Context manager for a ZMQ socket"""
22132225

22142226
ctx = zmq.Context() # type: ignore[attr-defined]
22152227
try:
2216-
yield make_zmq_socket(ctx, path, socket_type)
2217-
2228+
yield make_zmq_socket(ctx,
2229+
path,
2230+
socket_type,
2231+
bind=bind,
2232+
identity=identity)
22182233
except KeyboardInterrupt:
22192234
logger.debug("Got Keyboard Interrupt.")
22202235

vllm/v1/engine/core.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def __init__(
313313
Any]] = queue.Queue()
314314
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
315315
threading.Thread(target=self.process_input_socket,
316-
args=(input_path, ),
316+
args=(input_path, engine_index),
317317
daemon=True).start()
318318
threading.Thread(target=self.process_output_socket,
319319
args=(output_path, engine_index),
@@ -462,14 +462,18 @@ def _convert_msgspec_args(method, args):
462462
and not isinstance(v, p.annotation) else v
463463
for v, p in zip(args, arg_types))
464464

465-
def process_input_socket(self, input_path: str):
465+
def process_input_socket(self, input_path: str, engine_index: int):
466466
"""Input socket IO thread."""
467467

468468
# Msgpack serialization decoding.
469469
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
470470
generic_decoder = MsgpackDecoder()
471+
identity = engine_index.to_bytes(length=2)
471472

472-
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
473+
with zmq_socket_ctx(input_path,
474+
zmq.DEALER,
475+
identity=identity,
476+
bind=False) as socket:
473477
while True:
474478
# (RequestType, RequestData)
475479
type_frame, data_frame = socket.recv_multipart(copy=False)

vllm/v1/engine/core_client.py

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import uuid
99
import weakref
1010
from abc import ABC, abstractmethod
11-
from collections.abc import Awaitable, Sequence
11+
from collections.abc import Awaitable
1212
from concurrent.futures import Future
1313
from dataclasses import dataclass, field
1414
from threading import Thread
@@ -243,15 +243,12 @@ def __init__(
243243
vllm_config: VllmConfig,
244244
executor_class: type[Executor],
245245
log_stats: bool,
246-
ctx: Union[zmq.Context, zmq.asyncio.Context],
246+
input_path: str,
247247
output_path: str,
248248
index: int = 0,
249249
local_dp_rank: int = 0,
250250
):
251-
# Paths and sockets for IPC.
252-
input_path = get_open_zmq_ipc_path()
253-
self.input_socket = make_zmq_socket(ctx, input_path,
254-
zmq.constants.PUSH)
251+
self.identity = index.to_bytes(length=2)
255252
try:
256253
# Start EngineCore in background process.
257254
self.proc_handle = BackgroundProcHandle(
@@ -273,14 +270,9 @@ def __init__(
273270
# Ensure socket is closed if process fails to start.
274271
self.close()
275272

276-
def send_multipart(self, msg_parts: Sequence):
277-
return self.input_socket.send_multipart(msg_parts, copy=False)
278-
279273
def close(self):
280274
if proc_handle := getattr(self, "proc_handle", None):
281275
proc_handle.shutdown()
282-
if socket := getattr(self, "input_socket", None):
283-
socket.close(linger=0)
284276

285277

286278
@dataclass
@@ -291,6 +283,7 @@ class BackgroundResources:
291283
ctx: Union[zmq.Context]
292284
core_engines: list[CoreEngine] = field(default_factory=list)
293285
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
286+
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
294287
shutdown_path: Optional[str] = None
295288

296289
def __call__(self):
@@ -303,6 +296,8 @@ def __call__(self):
303296
# aren't explicitly closed first.
304297
if self.output_socket is not None:
305298
self.output_socket.close(linger=0)
299+
if self.input_socket is not None:
300+
self.input_socket.close(linger=0)
306301
if self.shutdown_path is not None:
307302
# We must ensure that the sync output socket is
308303
# closed cleanly in its own thread.
@@ -369,10 +364,16 @@ def sigusr1_handler(signum, frame):
369364

370365
# Paths and sockets for IPC.
371366
self.output_path = get_open_zmq_ipc_path()
367+
input_path = get_open_zmq_ipc_path()
368+
self.input_socket = make_zmq_socket(self.ctx,
369+
input_path,
370+
zmq.ROUTER,
371+
bind=True)
372+
self.resources.input_socket = self.input_socket
372373

373374
new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
374-
vllm_config, executor_class, log_stats, self.ctx, self.output_path,
375-
index, local_dp_rank)
375+
vllm_config, executor_class, log_stats, input_path, self.
376+
output_path, index, local_dp_rank)
376377

377378
# Start engine core process(es).
378379
self._init_core_engines(vllm_config, new_core_engine,
@@ -476,9 +477,10 @@ def get_output(self) -> EngineCoreOutputs:
476477
return self.outputs_queue.get()
477478

478479
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
479-
# (RequestType, SerializedRequest)
480-
msg = (request_type.value, self.encoder.encode(request))
481-
self.core_engine.send_multipart(msg)
480+
# (Identity, RequestType, SerializedRequest)
481+
msg = (self.core_engine.identity, request_type.value,
482+
self.encoder.encode(request))
483+
self.input_socket.send_multipart(msg, copy=False)
482484

483485
def call_utility(self, method: str, *args) -> Any:
484486
call_id = uuid.uuid1().int >> 64
@@ -601,30 +603,34 @@ async def get_output_async(self) -> EngineCoreOutputs:
601603
assert self.outputs_queue is not None
602604
return await self.outputs_queue.get()
603605

604-
async def _send_input(self, request_type: EngineCoreRequestType,
605-
request: Any) -> None:
606-
await self.core_engine.send_multipart(
607-
(request_type.value, self.encoder.encode(request)))
606+
def _send_input(self,
607+
request_type: EngineCoreRequestType,
608+
request: Any,
609+
engine: Optional[CoreEngine] = None) -> Awaitable[None]:
610+
if engine is None:
611+
engine = self.core_engine
608612

609-
self._ensure_output_queue_task()
613+
message = (request_type.value, self.encoder.encode(request))
614+
return self._send_input_message(message, engine)
615+
616+
def _send_input_message(self, message: tuple[bytes, bytes],
617+
engine: CoreEngine) -> Awaitable[None]:
618+
message = (engine.identity, ) + message # type: ignore[assignment]
619+
return self.input_socket.send_multipart(message, copy=False)
610620

611621
async def call_utility_async(self, method: str, *args) -> Any:
612622
return await self._call_utility_async(method,
613623
*args,
614624
engine=self.core_engine)
615625

616-
async def _call_utility_async(
617-
self,
618-
method: str,
619-
*args,
620-
engine: CoreEngine,
621-
) -> Any:
626+
async def _call_utility_async(self, method: str, *args,
627+
engine: CoreEngine) -> Any:
622628
call_id = uuid.uuid1().int >> 64
623629
future = asyncio.get_running_loop().create_future()
624630
self.utility_results[call_id] = future
625631
message = (EngineCoreRequestType.UTILITY.value,
626632
self.encoder.encode((call_id, method, args)))
627-
await engine.send_multipart(message)
633+
await self._send_input_message(message, engine)
628634
self._ensure_output_queue_task()
629635
return await future
630636

@@ -633,6 +639,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
633639
# tokenized.
634640
request.prompt = None
635641
await self._send_input(EngineCoreRequestType.ADD, request)
642+
self._ensure_output_queue_task()
636643

637644
async def abort_requests_async(self, request_ids: list[str]) -> None:
638645
if len(request_ids) > 0:
@@ -730,15 +737,15 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
730737
self.reqs_in_flight[request.request_id] = chosen_engine
731738
chosen_engine.num_reqs_in_flight += 1
732739
if self.num_engines_running >= len(self.core_engines):
733-
await chosen_engine.send_multipart(msg)
740+
await self._send_input_message(msg, chosen_engine)
734741
else:
735742
# Send request to chosen engine and dp start loop
736743
# control message to all other engines.
737744
self.num_engines_running += len(self.core_engines)
738745
await asyncio.gather(*[
739-
engine.send_multipart(msg if engine is
740-
chosen_engine else self.start_dp_msg)
741-
for engine in self.core_engines
746+
self._send_input_message(
747+
msg if engine is chosen_engine else self.start_dp_msg,
748+
engine) for engine in self.core_engines
742749
])
743750

744751
self._ensure_output_queue_task()
@@ -763,7 +770,7 @@ async def process_engine_outputs(self: "DPAsyncMPClient",
763770
# sure to start the other engines:
764771
self.num_engines_running = len(self.core_engines)
765772
coros = [
766-
engine.send_multipart(self.start_dp_msg)
773+
self._send_input_message(self.start_dp_msg, engine)
767774
for engine in self.core_engines
768775
if not engine.num_reqs_in_flight
769776
]
@@ -789,5 +796,5 @@ async def abort_requests_async(self, request_ids: list[str]) -> None:
789796

790797
async def _abort_requests(self, request_ids: list[str],
791798
engine: CoreEngine) -> None:
792-
await engine.send_multipart((EngineCoreRequestType.ABORT.value,
793-
self.encoder.encode(request_ids)))
799+
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
800+
engine)

0 commit comments

Comments
 (0)