8
8
import uuid
9
9
import weakref
10
10
from abc import ABC , abstractmethod
11
- from collections .abc import Awaitable , Sequence
11
+ from collections .abc import Awaitable
12
12
from concurrent .futures import Future
13
13
from dataclasses import dataclass , field
14
14
from threading import Thread
@@ -243,15 +243,12 @@ def __init__(
243
243
vllm_config : VllmConfig ,
244
244
executor_class : type [Executor ],
245
245
log_stats : bool ,
246
- ctx : Union [ zmq . Context , zmq . asyncio . Context ] ,
246
+ input_path : str ,
247
247
output_path : str ,
248
248
index : int = 0 ,
249
249
local_dp_rank : int = 0 ,
250
250
):
251
- # Paths and sockets for IPC.
252
- input_path = get_open_zmq_ipc_path ()
253
- self .input_socket = make_zmq_socket (ctx , input_path ,
254
- zmq .constants .PUSH )
251
+ self .identity = index .to_bytes (length = 2 )
255
252
try :
256
253
# Start EngineCore in background process.
257
254
self .proc_handle = BackgroundProcHandle (
@@ -273,14 +270,9 @@ def __init__(
273
270
# Ensure socket is closed if process fails to start.
274
271
self .close ()
275
272
276
- def send_multipart (self , msg_parts : Sequence ):
277
- return self .input_socket .send_multipart (msg_parts , copy = False )
278
-
279
273
def close (self ):
280
274
if proc_handle := getattr (self , "proc_handle" , None ):
281
275
proc_handle .shutdown ()
282
- if socket := getattr (self , "input_socket" , None ):
283
- socket .close (linger = 0 )
284
276
285
277
286
278
@dataclass
@@ -291,6 +283,7 @@ class BackgroundResources:
291
283
ctx : Union [zmq .Context ]
292
284
core_engines : list [CoreEngine ] = field (default_factory = list )
293
285
output_socket : Optional [Union [zmq .Socket , zmq .asyncio .Socket ]] = None
286
+ input_socket : Optional [Union [zmq .Socket , zmq .asyncio .Socket ]] = None
294
287
shutdown_path : Optional [str ] = None
295
288
296
289
def __call__ (self ):
@@ -303,6 +296,8 @@ def __call__(self):
303
296
# aren't explicitly closed first.
304
297
if self .output_socket is not None :
305
298
self .output_socket .close (linger = 0 )
299
+ if self .input_socket is not None :
300
+ self .input_socket .close (linger = 0 )
306
301
if self .shutdown_path is not None :
307
302
# We must ensure that the sync output socket is
308
303
# closed cleanly in its own thread.
@@ -369,10 +364,16 @@ def sigusr1_handler(signum, frame):
369
364
370
365
# Paths and sockets for IPC.
371
366
self .output_path = get_open_zmq_ipc_path ()
367
+ input_path = get_open_zmq_ipc_path ()
368
+ self .input_socket = make_zmq_socket (self .ctx ,
369
+ input_path ,
370
+ zmq .ROUTER ,
371
+ bind = True )
372
+ self .resources .input_socket = self .input_socket
372
373
373
374
new_core_engine = lambda index , local_dp_rank = None : CoreEngine (
374
- vllm_config , executor_class , log_stats , self . ctx , self .output_path ,
375
- index , local_dp_rank )
375
+ vllm_config , executor_class , log_stats , input_path , self .
376
+ output_path , index , local_dp_rank )
376
377
377
378
# Start engine core process(es).
378
379
self ._init_core_engines (vllm_config , new_core_engine ,
@@ -476,9 +477,10 @@ def get_output(self) -> EngineCoreOutputs:
476
477
return self .outputs_queue .get ()
477
478
478
479
def _send_input (self , request_type : EngineCoreRequestType , request : Any ):
479
- # (RequestType, SerializedRequest)
480
- msg = (request_type .value , self .encoder .encode (request ))
481
- self .core_engine .send_multipart (msg )
480
+ # (Identity, RequestType, SerializedRequest)
481
+ msg = (self .core_engine .identity , request_type .value ,
482
+ self .encoder .encode (request ))
483
+ self .input_socket .send_multipart (msg , copy = False )
482
484
483
485
def call_utility (self , method : str , * args ) -> Any :
484
486
call_id = uuid .uuid1 ().int >> 64
@@ -601,30 +603,34 @@ async def get_output_async(self) -> EngineCoreOutputs:
601
603
assert self .outputs_queue is not None
602
604
return await self .outputs_queue .get ()
603
605
604
- async def _send_input (self , request_type : EngineCoreRequestType ,
605
- request : Any ) -> None :
606
- await self .core_engine .send_multipart (
607
- (request_type .value , self .encoder .encode (request )))
606
+ def _send_input (self ,
607
+ request_type : EngineCoreRequestType ,
608
+ request : Any ,
609
+ engine : Optional [CoreEngine ] = None ) -> Awaitable [None ]:
610
+ if engine is None :
611
+ engine = self .core_engine
608
612
609
- self ._ensure_output_queue_task ()
613
+ message = (request_type .value , self .encoder .encode (request ))
614
+ return self ._send_input_message (message , engine )
615
+
616
+ def _send_input_message (self , message : tuple [bytes , bytes ],
617
+ engine : CoreEngine ) -> Awaitable [None ]:
618
+ message = (engine .identity , ) + message # type: ignore[assignment]
619
+ return self .input_socket .send_multipart (message , copy = False )
610
620
611
621
async def call_utility_async (self , method : str , * args ) -> Any :
612
622
return await self ._call_utility_async (method ,
613
623
* args ,
614
624
engine = self .core_engine )
615
625
616
- async def _call_utility_async (
617
- self ,
618
- method : str ,
619
- * args ,
620
- engine : CoreEngine ,
621
- ) -> Any :
626
+ async def _call_utility_async (self , method : str , * args ,
627
+ engine : CoreEngine ) -> Any :
622
628
call_id = uuid .uuid1 ().int >> 64
623
629
future = asyncio .get_running_loop ().create_future ()
624
630
self .utility_results [call_id ] = future
625
631
message = (EngineCoreRequestType .UTILITY .value ,
626
632
self .encoder .encode ((call_id , method , args )))
627
- await engine . send_multipart (message )
633
+ await self . _send_input_message (message , engine )
628
634
self ._ensure_output_queue_task ()
629
635
return await future
630
636
@@ -633,6 +639,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
633
639
# tokenized.
634
640
request .prompt = None
635
641
await self ._send_input (EngineCoreRequestType .ADD , request )
642
+ self ._ensure_output_queue_task ()
636
643
637
644
async def abort_requests_async (self , request_ids : list [str ]) -> None :
638
645
if len (request_ids ) > 0 :
@@ -730,15 +737,15 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
730
737
self .reqs_in_flight [request .request_id ] = chosen_engine
731
738
chosen_engine .num_reqs_in_flight += 1
732
739
if self .num_engines_running >= len (self .core_engines ):
733
- await chosen_engine . send_multipart (msg )
740
+ await self . _send_input_message (msg , chosen_engine )
734
741
else :
735
742
# Send request to chosen engine and dp start loop
736
743
# control message to all other engines.
737
744
self .num_engines_running += len (self .core_engines )
738
745
await asyncio .gather (* [
739
- engine . send_multipart ( msg if engine is
740
- chosen_engine else self .start_dp_msg )
741
- for engine in self .core_engines
746
+ self . _send_input_message (
747
+ msg if engine is chosen_engine else self .start_dp_msg ,
748
+ engine ) for engine in self .core_engines
742
749
])
743
750
744
751
self ._ensure_output_queue_task ()
@@ -763,7 +770,7 @@ async def process_engine_outputs(self: "DPAsyncMPClient",
763
770
# sure to start the other engines:
764
771
self .num_engines_running = len (self .core_engines )
765
772
coros = [
766
- engine . send_multipart (self .start_dp_msg )
773
+ self . _send_input_message (self .start_dp_msg , engine )
767
774
for engine in self .core_engines
768
775
if not engine .num_reqs_in_flight
769
776
]
@@ -789,5 +796,5 @@ async def abort_requests_async(self, request_ids: list[str]) -> None:
789
796
790
797
async def _abort_requests (self , request_ids : list [str ],
791
798
engine : CoreEngine ) -> None :
792
- await engine . send_multipart (( EngineCoreRequestType .ABORT . value ,
793
- self . encoder . encode ( request_ids )) )
799
+ await self . _send_input ( EngineCoreRequestType .ABORT , request_ids ,
800
+ engine )
0 commit comments