8
8
import uuid
9
9
import weakref
10
10
from abc import ABC , abstractmethod
11
- from collections .abc import Awaitable
11
+ from collections .abc import Awaitable , Sequence
12
12
from concurrent .futures import Future
13
13
from dataclasses import dataclass , field
14
14
from threading import Thread
35
35
36
36
_R = TypeVar ('_R' ) # Return type for collective_rpc
37
37
38
- STARTUP_POLL_PERIOD_MS = 10000
39
-
40
38
41
39
class EngineCoreClient (ABC ):
42
40
"""
@@ -263,13 +261,15 @@ def __init__(
263
261
vllm_config : VllmConfig ,
264
262
executor_class : type [Executor ],
265
263
log_stats : bool ,
266
- input_path : str ,
264
+ ctx : Union [ zmq . Context , zmq . asyncio . Context ] ,
267
265
output_path : str ,
268
266
index : int = 0 ,
269
267
local_dp_rank : int = 0 ,
270
268
):
271
- self .index = index
272
- self .identity = index .to_bytes (length = 2 , byteorder = "little" )
269
+ # Paths and sockets for IPC.
270
+ input_path = get_open_zmq_ipc_path ()
271
+ self .input_socket = make_zmq_socket (ctx , input_path ,
272
+ zmq .constants .PUSH )
273
273
try :
274
274
# Start EngineCore in background process.
275
275
self .proc_handle = BackgroundProcHandle (
@@ -291,9 +291,14 @@ def __init__(
291
291
# Ensure socket is closed if process fails to start.
292
292
self .close ()
293
293
294
+ def send_multipart (self , msg_parts : Sequence ):
295
+ return self .input_socket .send_multipart (msg_parts , copy = False )
296
+
294
297
def close (self ):
295
298
if proc_handle := getattr (self , "proc_handle" , None ):
296
299
proc_handle .shutdown ()
300
+ if socket := getattr (self , "input_socket" , None ):
301
+ socket .close (linger = 0 )
297
302
298
303
299
304
@dataclass
@@ -304,7 +309,6 @@ class BackgroundResources:
304
309
ctx : Union [zmq .Context ]
305
310
core_engines : list [CoreEngine ] = field (default_factory = list )
306
311
output_socket : Optional [Union [zmq .Socket , zmq .asyncio .Socket ]] = None
307
- input_socket : Optional [Union [zmq .Socket , zmq .asyncio .Socket ]] = None
308
312
shutdown_path : Optional [str ] = None
309
313
310
314
def __call__ (self ):
@@ -317,8 +321,6 @@ def __call__(self):
317
321
# aren't explicitly closed first.
318
322
if self .output_socket is not None :
319
323
self .output_socket .close (linger = 0 )
320
- if self .input_socket is not None :
321
- self .input_socket .close (linger = 0 )
322
324
if self .shutdown_path is not None :
323
325
# We must ensure that the sync output socket is
324
326
# closed cleanly in its own thread.
@@ -385,51 +387,21 @@ def sigusr1_handler(signum, frame):
385
387
386
388
# Paths and sockets for IPC.
387
389
self .output_path = get_open_zmq_ipc_path ()
388
- input_path = get_open_zmq_ipc_path ()
389
- self .input_socket = make_zmq_socket (self .ctx ,
390
- input_path ,
391
- zmq .ROUTER ,
392
- bind = True )
393
- self .resources .input_socket = self .input_socket
394
390
395
391
new_core_engine = lambda index , local_dp_rank = None : CoreEngine (
396
- vllm_config , executor_class , log_stats , input_path , self .
397
- output_path , index , local_dp_rank )
392
+ vllm_config , executor_class , log_stats , self . ctx , self .output_path ,
393
+ index , local_dp_rank )
398
394
399
395
# Start engine core process(es).
400
396
self ._init_core_engines (vllm_config , new_core_engine ,
401
397
self .resources .core_engines )
402
398
403
399
# Wait for engine core process(es) to start.
404
- self ._wait_for_engine_startup ()
400
+ for engine in self .resources .core_engines :
401
+ engine .proc_handle .wait_for_startup ()
405
402
406
403
self .utility_results : dict [int , AnyFuture ] = {}
407
404
408
- def _wait_for_engine_startup (self ):
409
- # Get a sync handle to the socket which can be sync or async.
410
- sync_input_socket = zmq .Socket .shadow (self .input_socket )
411
-
412
- # Wait for engine core process(es) to send ready messages.
413
- identities = set (eng .index for eng in self .resources .core_engines )
414
- while identities :
415
- while not sync_input_socket .poll (timeout = STARTUP_POLL_PERIOD_MS ):
416
- logger .info ("Waiting for %d core engine proc(s) to start: %s" ,
417
- len (identities ), identities )
418
- eng_id_bytes , msg = sync_input_socket .recv_multipart ()
419
- eng_id = int .from_bytes (eng_id_bytes , byteorder = "little" )
420
- if eng_id not in identities :
421
- raise RuntimeError (f"Unexpected or duplicate engine: { eng_id } " )
422
- if msg != b'READY' :
423
- raise RuntimeError (f"Engine { eng_id } failed: { msg .decode ()} " )
424
- logger .info ("Core engine process %d ready." , eng_id )
425
- identities .discard (eng_id )
426
-
427
- # Double check that the process are running.
428
- for engine in self .resources .core_engines :
429
- proc = engine .proc_handle .proc
430
- if proc .exitcode is not None :
431
- raise RuntimeError (f"Engine proc { proc .name } not running" )
432
-
433
405
def _init_core_engines (
434
406
self ,
435
407
vllm_config : VllmConfig ,
@@ -522,10 +494,9 @@ def get_output(self) -> EngineCoreOutputs:
522
494
return self .outputs_queue .get ()
523
495
524
496
def _send_input (self , request_type : EngineCoreRequestType , request : Any ):
525
- # (Identity, RequestType, SerializedRequest)
526
- msg = (self .core_engine .identity , request_type .value ,
527
- self .encoder .encode (request ))
528
- self .input_socket .send_multipart (msg , copy = False )
497
+ # (RequestType, SerializedRequest)
498
+ msg = (request_type .value , self .encoder .encode (request ))
499
+ self .core_engine .send_multipart (msg )
529
500
530
501
def call_utility (self , method : str , * args ) -> Any :
531
502
call_id = uuid .uuid1 ().int >> 64
@@ -654,34 +625,30 @@ async def get_output_async(self) -> EngineCoreOutputs:
654
625
assert self .outputs_queue is not None
655
626
return await self .outputs_queue .get ()
656
627
657
- def _send_input (self ,
658
- request_type : EngineCoreRequestType ,
659
- request : Any ,
660
- engine : Optional [CoreEngine ] = None ) -> Awaitable [None ]:
661
- if engine is None :
662
- engine = self .core_engine
663
-
664
- message = (request_type .value , self .encoder .encode (request ))
665
- return self ._send_input_message (message , engine )
628
+ async def _send_input (self , request_type : EngineCoreRequestType ,
629
+ request : Any ) -> None :
630
+ await self .core_engine .send_multipart (
631
+ (request_type .value , self .encoder .encode (request )))
666
632
667
- def _send_input_message (self , message : tuple [bytes , bytes ],
668
- engine : CoreEngine ) -> Awaitable [None ]:
669
- message = (engine .identity , ) + message # type: ignore[assignment]
670
- return self .input_socket .send_multipart (message , copy = False )
633
+ self ._ensure_output_queue_task ()
671
634
672
635
async def call_utility_async (self , method : str , * args ) -> Any :
673
636
return await self ._call_utility_async (method ,
674
637
* args ,
675
638
engine = self .core_engine )
676
639
677
- async def _call_utility_async (self , method : str , * args ,
678
- engine : CoreEngine ) -> Any :
640
+ async def _call_utility_async (
641
+ self ,
642
+ method : str ,
643
+ * args ,
644
+ engine : CoreEngine ,
645
+ ) -> Any :
679
646
call_id = uuid .uuid1 ().int >> 64
680
647
future = asyncio .get_running_loop ().create_future ()
681
648
self .utility_results [call_id ] = future
682
649
message = (EngineCoreRequestType .UTILITY .value ,
683
650
self .encoder .encode ((call_id , method , args )))
684
- await self . _send_input_message (message , engine )
651
+ await engine . send_multipart (message )
685
652
self ._ensure_output_queue_task ()
686
653
return await future
687
654
@@ -690,7 +657,6 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
690
657
# tokenized.
691
658
request .prompt = None
692
659
await self ._send_input (EngineCoreRequestType .ADD , request )
693
- self ._ensure_output_queue_task ()
694
660
695
661
async def abort_requests_async (self , request_ids : list [str ]) -> None :
696
662
if len (request_ids ) > 0 :
@@ -795,15 +761,15 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
795
761
self .reqs_in_flight [request .request_id ] = chosen_engine
796
762
chosen_engine .num_reqs_in_flight += 1
797
763
if self .num_engines_running >= len (self .core_engines ):
798
- await self . _send_input_message (msg , chosen_engine )
764
+ await chosen_engine . send_multipart (msg )
799
765
else :
800
766
# Send request to chosen engine and dp start loop
801
767
# control message to all other engines.
802
768
self .num_engines_running += len (self .core_engines )
803
769
await asyncio .gather (* [
804
- self . _send_input_message (
805
- msg if engine is chosen_engine else self .start_dp_msg ,
806
- engine ) for engine in self .core_engines
770
+ engine . send_multipart ( msg if engine is
771
+ chosen_engine else self .start_dp_msg )
772
+ for engine in self .core_engines
807
773
])
808
774
809
775
self ._ensure_output_queue_task ()
@@ -828,7 +794,7 @@ async def process_engine_outputs(self: "DPAsyncMPClient",
828
794
# sure to start the other engines:
829
795
self .num_engines_running = len (self .core_engines )
830
796
coros = [
831
- self . _send_input_message (self .start_dp_msg , engine )
797
+ engine . send_multipart (self .start_dp_msg )
832
798
for engine in self .core_engines
833
799
if not engine .num_reqs_in_flight
834
800
]
@@ -854,5 +820,5 @@ async def abort_requests_async(self, request_ids: list[str]) -> None:
854
820
855
821
async def _abort_requests (self , request_ids : list [str ],
856
822
engine : CoreEngine ) -> None :
857
- await self . _send_input ( EngineCoreRequestType .ABORT , request_ids ,
858
- engine )
823
+ await engine . send_multipart (( EngineCoreRequestType .ABORT . value ,
824
+ self . encoder . encode ( request_ids )) )
0 commit comments