8
8
import uuid
9
9
import weakref
10
10
from abc import ABC , abstractmethod
11
- from collections .abc import Awaitable , Sequence
11
+ from collections .abc import Awaitable
12
12
from concurrent .futures import Future
13
13
from dataclasses import dataclass , field
14
14
from threading import Thread
35
35
36
36
_R = TypeVar ('_R' ) # Return type for collective_rpc
37
37
38
+ STARTUP_POLL_PERIOD_MS = 10000
39
+
38
40
39
41
class EngineCoreClient (ABC ):
40
42
"""
@@ -261,15 +263,13 @@ def __init__(
261
263
vllm_config : VllmConfig ,
262
264
executor_class : type [Executor ],
263
265
log_stats : bool ,
264
- ctx : Union [ zmq . Context , zmq . asyncio . Context ] ,
266
+ input_path : str ,
265
267
output_path : str ,
266
268
index : int = 0 ,
267
269
local_dp_rank : int = 0 ,
268
270
):
269
- # Paths and sockets for IPC.
270
- input_path = get_open_zmq_ipc_path ()
271
- self .input_socket = make_zmq_socket (ctx , input_path ,
272
- zmq .constants .PUSH )
271
+ self .index = index
272
+ self .identity = index .to_bytes (length = 2 , byteorder = "little" )
273
273
try :
274
274
# Start EngineCore in background process.
275
275
self .proc_handle = BackgroundProcHandle (
@@ -291,14 +291,9 @@ def __init__(
291
291
# Ensure socket is closed if process fails to start.
292
292
self .close ()
293
293
294
- def send_multipart (self , msg_parts : Sequence ):
295
- return self .input_socket .send_multipart (msg_parts , copy = False )
296
-
297
294
def close (self ):
298
295
if proc_handle := getattr (self , "proc_handle" , None ):
299
296
proc_handle .shutdown ()
300
- if socket := getattr (self , "input_socket" , None ):
301
- socket .close (linger = 0 )
302
297
303
298
304
299
@dataclass
@@ -309,6 +304,7 @@ class BackgroundResources:
309
304
ctx : Union [zmq .Context ]
310
305
core_engines : list [CoreEngine ] = field (default_factory = list )
311
306
output_socket : Optional [Union [zmq .Socket , zmq .asyncio .Socket ]] = None
307
+ input_socket : Optional [Union [zmq .Socket , zmq .asyncio .Socket ]] = None
312
308
shutdown_path : Optional [str ] = None
313
309
314
310
def __call__ (self ):
@@ -321,6 +317,8 @@ def __call__(self):
321
317
# aren't explicitly closed first.
322
318
if self .output_socket is not None :
323
319
self .output_socket .close (linger = 0 )
320
+ if self .input_socket is not None :
321
+ self .input_socket .close (linger = 0 )
324
322
if self .shutdown_path is not None :
325
323
# We must ensure that the sync output socket is
326
324
# closed cleanly in its own thread.
@@ -387,21 +385,51 @@ def sigusr1_handler(signum, frame):
387
385
388
386
# Paths and sockets for IPC.
389
387
self .output_path = get_open_zmq_ipc_path ()
388
+ input_path = get_open_zmq_ipc_path ()
389
+ self .input_socket = make_zmq_socket (self .ctx ,
390
+ input_path ,
391
+ zmq .ROUTER ,
392
+ bind = True )
393
+ self .resources .input_socket = self .input_socket
390
394
391
395
new_core_engine = lambda index , local_dp_rank = None : CoreEngine (
392
- vllm_config , executor_class , log_stats , self . ctx , self .output_path ,
393
- index , local_dp_rank )
396
+ vllm_config , executor_class , log_stats , input_path , self .
397
+ output_path , index , local_dp_rank )
394
398
395
399
# Start engine core process(es).
396
400
self ._init_core_engines (vllm_config , new_core_engine ,
397
401
self .resources .core_engines )
398
402
399
403
# Wait for engine core process(es) to start.
400
- for engine in self .resources .core_engines :
401
- engine .proc_handle .wait_for_startup ()
404
+ self ._wait_for_engine_startup ()
402
405
403
406
self .utility_results : dict [int , AnyFuture ] = {}
404
407
408
+ def _wait_for_engine_startup (self ):
409
+ # Get a sync handle to the input socket, which may be sync or async.
410
+ sync_input_socket = zmq .Socket .shadow (self .input_socket )
411
+
412
+ # Wait for engine core process(es) to send ready messages.
413
+ identities = set (eng .index for eng in self .resources .core_engines )
414
+ while identities :
415
+ while not sync_input_socket .poll (timeout = STARTUP_POLL_PERIOD_MS ):
416
+ logger .info ("Waiting for %d core engine proc(s) to start: %s" ,
417
+ len (identities ), identities )
418
+ eng_id_bytes , msg = sync_input_socket .recv_multipart ()
419
+ eng_id = int .from_bytes (eng_id_bytes , byteorder = "little" )
420
+ if eng_id not in identities :
421
+ raise RuntimeError (f"Unexpected or duplicate engine: { eng_id } " )
422
+ if msg != b'READY' :
423
+ raise RuntimeError (f"Engine { eng_id } failed: { msg .decode ()} " )
424
+ logger .info ("Core engine process %d ready." , eng_id )
425
+ identities .discard (eng_id )
426
+
427
+ # Double check that the processes are running.
428
+ for engine in self .resources .core_engines :
429
+ proc = engine .proc_handle .proc
430
+ if proc .exitcode is not None :
431
+ raise RuntimeError (f"Engine proc { proc .name } not running" )
432
+
405
433
def _init_core_engines (
406
434
self ,
407
435
vllm_config : VllmConfig ,
@@ -494,9 +522,10 @@ def get_output(self) -> EngineCoreOutputs:
494
522
return self .outputs_queue .get ()
495
523
496
524
def _send_input (self , request_type : EngineCoreRequestType , request : Any ):
497
- # (RequestType, SerializedRequest)
498
- msg = (request_type .value , self .encoder .encode (request ))
499
- self .core_engine .send_multipart (msg )
525
+ # (Identity, RequestType, SerializedRequest)
526
+ msg = (self .core_engine .identity , request_type .value ,
527
+ self .encoder .encode (request ))
528
+ self .input_socket .send_multipart (msg , copy = False )
500
529
501
530
def call_utility (self , method : str , * args ) -> Any :
502
531
call_id = uuid .uuid1 ().int >> 64
@@ -625,30 +654,34 @@ async def get_output_async(self) -> EngineCoreOutputs:
625
654
assert self .outputs_queue is not None
626
655
return await self .outputs_queue .get ()
627
656
628
- async def _send_input (self , request_type : EngineCoreRequestType ,
629
- request : Any ) -> None :
630
- await self .core_engine .send_multipart (
631
- (request_type .value , self .encoder .encode (request )))
657
+ def _send_input (self ,
658
+ request_type : EngineCoreRequestType ,
659
+ request : Any ,
660
+ engine : Optional [CoreEngine ] = None ) -> Awaitable [None ]:
661
+ if engine is None :
662
+ engine = self .core_engine
632
663
633
- self ._ensure_output_queue_task ()
664
+ message = (request_type .value , self .encoder .encode (request ))
665
+ return self ._send_input_message (message , engine )
666
+
667
+ def _send_input_message (self , message : tuple [bytes , bytes ],
668
+ engine : CoreEngine ) -> Awaitable [None ]:
669
+ message = (engine .identity , ) + message # type: ignore[assignment]
670
+ return self .input_socket .send_multipart (message , copy = False )
634
671
635
672
async def call_utility_async (self , method : str , * args ) -> Any :
636
673
return await self ._call_utility_async (method ,
637
674
* args ,
638
675
engine = self .core_engine )
639
676
640
- async def _call_utility_async (
641
- self ,
642
- method : str ,
643
- * args ,
644
- engine : CoreEngine ,
645
- ) -> Any :
677
+ async def _call_utility_async (self , method : str , * args ,
678
+ engine : CoreEngine ) -> Any :
646
679
call_id = uuid .uuid1 ().int >> 64
647
680
future = asyncio .get_running_loop ().create_future ()
648
681
self .utility_results [call_id ] = future
649
682
message = (EngineCoreRequestType .UTILITY .value ,
650
683
self .encoder .encode ((call_id , method , args )))
651
- await engine . send_multipart (message )
684
+ await self . _send_input_message (message , engine )
652
685
self ._ensure_output_queue_task ()
653
686
return await future
654
687
@@ -657,6 +690,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
657
690
# tokenized.
658
691
request .prompt = None
659
692
await self ._send_input (EngineCoreRequestType .ADD , request )
693
+ self ._ensure_output_queue_task ()
660
694
661
695
async def abort_requests_async (self , request_ids : list [str ]) -> None :
662
696
if len (request_ids ) > 0 :
@@ -761,15 +795,15 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
761
795
self .reqs_in_flight [request .request_id ] = chosen_engine
762
796
chosen_engine .num_reqs_in_flight += 1
763
797
if self .num_engines_running >= len (self .core_engines ):
764
- await chosen_engine . send_multipart (msg )
798
+ await self . _send_input_message (msg , chosen_engine )
765
799
else :
766
800
# Send request to chosen engine and dp start loop
767
801
# control message to all other engines.
768
802
self .num_engines_running += len (self .core_engines )
769
803
await asyncio .gather (* [
770
- engine . send_multipart ( msg if engine is
771
- chosen_engine else self .start_dp_msg )
772
- for engine in self .core_engines
804
+ self . _send_input_message (
805
+ msg if engine is chosen_engine else self .start_dp_msg ,
806
+ engine ) for engine in self .core_engines
773
807
])
774
808
775
809
self ._ensure_output_queue_task ()
@@ -794,7 +828,7 @@ async def process_engine_outputs(self: "DPAsyncMPClient",
794
828
# sure to start the other engines:
795
829
self .num_engines_running = len (self .core_engines )
796
830
coros = [
797
- engine . send_multipart (self .start_dp_msg )
831
+ self . _send_input_message (self .start_dp_msg , engine )
798
832
for engine in self .core_engines
799
833
if not engine .num_reqs_in_flight
800
834
]
@@ -820,5 +854,5 @@ async def abort_requests_async(self, request_ids: list[str]) -> None:
820
854
821
855
async def _abort_requests (self , request_ids : list [str ],
822
856
engine : CoreEngine ) -> None :
823
- await engine . send_multipart (( EngineCoreRequestType .ABORT . value ,
824
- self . encoder . encode ( request_ids )) )
857
+ await self . _send_input ( EngineCoreRequestType .ABORT , request_ids ,
858
+ engine )
0 commit comments