29
29
thread_local_data = threading .local ()
30
30
31
31
KV_MOVE_MAX_NUM = 16
32
+ KV_MOVE_MAX_RESTART_CNT = 3
32
33
33
34
34
35
@dataclass
35
36
class TransProcessObj :
36
37
prefill_node_id : int = None
38
+ process : mp .Process = None
37
39
task_in_queue : mp .Queue = None
38
40
task_out_queue : mp .Queue = None
39
- prefill_ip : str = None
40
- prefill_port : int = None
41
+ pd_prefill_nccl_ip : str = None
42
+ pd_prefill_nccl_port : int = None
41
43
device_index : int = None
42
44
manager : "DecodeKVMoveManager" = None
43
45
has_error : bool = False
@@ -47,32 +49,36 @@ class TransProcessObj:
47
49
put_to_radix_thread : threading .Thread = None
48
50
latest_check_time : float = None
49
51
50
- def create (self , prefill_node_id : str , prefill_ip : str , prefill_port : int , manager : "DecodeKVMoveManager" ):
52
+ def create (
53
+ self , prefill_node_id : str , pd_prefill_nccl_ip : str , pd_prefill_nccl_port : int , manager : "DecodeKVMoveManager"
54
+ ):
51
55
52
56
device_index = manager .get_next_device_index ()
53
57
decode_node_id = manager .args .pd_node_id
54
58
task_in_queue = manager .kv_trans_task_in_queues [device_index ]
55
59
task_out_queue = manager .kv_trans_task_out_queues [device_index ]
56
60
57
- task_in_queue .put (
58
- PDTransJoinInfo (
59
- prefill_id = prefill_node_id ,
60
- prefill_device_id = - 1 ,
61
- prefill_ip = prefill_ip ,
62
- prefill_port = prefill_port ,
63
- decode_id = decode_node_id ,
64
- decode_device_id = device_index ,
61
+ with manager .device_locks [device_index ]:
62
+ task_in_queue .put (
63
+ PDTransJoinInfo (
64
+ prefill_id = prefill_node_id ,
65
+ prefill_device_id = - 1 ,
66
+ pd_prefill_nccl_ip = pd_prefill_nccl_ip ,
67
+ pd_prefill_nccl_port = pd_prefill_nccl_port ,
68
+ decode_id = decode_node_id ,
69
+ decode_device_id = device_index ,
70
+ )
65
71
)
66
- )
67
- assert task_out_queue .get (timeout = 60 ) == "nccl_ok"
72
+ assert task_out_queue .get (timeout = 60 ) == "nccl_ok"
68
73
69
74
self .prefill_node_id = prefill_node_id
70
75
self .decode_node_id = decode_node_id
71
76
self .task_in_queue = task_in_queue
72
77
self .task_out_queue = task_out_queue
73
- self .prefill_ip = prefill_ip
74
- self .prefill_port = prefill_port
78
+ self .pd_prefill_nccl_ip = pd_prefill_nccl_ip
79
+ self .pd_prefill_nccl_port = pd_prefill_nccl_port
75
80
self .device_index = device_index
81
+ self .process = manager .kv_trans_processes [device_index ]
76
82
77
83
self .manager = manager
78
84
self .latest_check_time = time .time ()
@@ -90,6 +96,20 @@ def create(self, prefill_node_id: str, prefill_ip: str, prefill_port: int, manag
90
96
self .put_to_radix_thread .start ()
91
97
return
92
98
99
+ def check_trans_process (self , raise_exception = True ):
100
+ process = psutil .Process (self .process .pid )
101
+ if not (process .is_running () and process .status () != psutil .STATUS_ZOMBIE ):
102
+ self .set_has_error ()
103
+ if raise_exception :
104
+ raise Exception (f"trans process: { self .process .pid } is dead" )
105
+ return
106
+
107
+ def timer_to_check_status (self , raise_exception = True ):
108
+ if time .time () - self .latest_check_time >= 2.0 :
109
+ self .latest_check_time = time .time ()
110
+ self .check_trans_process (raise_exception = raise_exception )
111
+ return
112
+
93
113
def _transfer_kv (self , move_tasks : List [KVMoveTask ]):
94
114
with self .manager .device_locks [self .device_index ]:
95
115
self .task_in_queue .put (move_tasks .copy (), timeout = 10 )
@@ -120,6 +140,7 @@ def kv_move_loop(self):
120
140
logger .info (f"{ func_name } get task { task .to_decode_log_info ()} " )
121
141
122
142
try :
143
+ self .timer_to_check_status (raise_exception = True )
123
144
if not kv_trans_use_p2p ():
124
145
with self .manager .kv_trans_lock :
125
146
self ._transfer_kv (move_tasks )
@@ -150,6 +171,7 @@ def put_to_radix_loop(self):
150
171
logger .info (f"{ func_name } get put radix task { task .to_decode_log_info ()} " )
151
172
152
173
try :
174
+ self .timer_to_check_status (raise_exception = True )
153
175
# random to check stats
154
176
self .manager ._put_kv_received_to_radix_cache (move_tasks .copy ())
155
177
for task in move_tasks .copy ():
@@ -266,31 +288,17 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
266
288
# 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。
267
289
self .device_locks = [threading .Lock () for _ in range (self .node_world_size )]
268
290
269
- from .decode_trans_process import start_decode_trans_process
270
-
271
291
self .kv_trans_processes = []
272
292
self .kv_trans_task_in_queues = []
273
293
self .kv_trans_task_out_queues = []
274
- self .kv_trans_process_alive = []
275
-
276
- for device_index in range (self .node_world_size ):
277
- kv_trans_task_in_queue = mp .Queue ()
278
- kv_trans_task_out_queue = mp .Queue ()
279
- kv_trans_process = start_decode_trans_process (
280
- self .args ,
281
- device_index ,
282
- kv_trans_task_in_queue ,
283
- kv_trans_task_out_queue ,
284
- self .mem_queues ,
285
- )
286
- assert kv_trans_task_out_queue .get (timeout = 30 ) == "proc_start"
287
- self ._put_mem_manager_to_mem_queue ()
288
- assert kv_trans_task_out_queue .get (timeout = 60 ) == "get_mem_managers_ok"
294
+ self .kv_trans_process_restart_cnt = []
289
295
290
- self .kv_trans_processes .append (kv_trans_process )
291
- self .kv_trans_task_in_queues .append (kv_trans_task_in_queue )
292
- self .kv_trans_task_out_queues .append (kv_trans_task_out_queue )
293
- self .kv_trans_process_alive .append (True )
296
+ for device_id in range (self .node_world_size ):
297
+ self .kv_trans_task_in_queues .append (mp .Queue ())
298
+ self .kv_trans_task_out_queues .append (mp .Queue ())
299
+ self .kv_trans_process_restart_cnt .append (0 )
300
+ self .kv_trans_processes .append (None )
301
+ assert self .start_trans_process (device_id )
294
302
295
303
return
296
304
@@ -400,17 +408,19 @@ def exposed_check_alive(self):
400
408
# 用于 prefill node check 通信连接的状态。
401
409
return
402
410
403
- def exposed_build_trans_process (self , prefill_node_id , prefill_ip , prefill_port , prefill_node_max_kv_trans_num ):
404
- prefill_node_id , prefill_ip , prefill_port , prefill_node_max_kv_trans_num = list (
405
- map (obtain , [prefill_node_id , prefill_ip , prefill_port , prefill_node_max_kv_trans_num ])
411
+ def exposed_build_trans_process (
412
+ self , prefill_node_id , pd_prefill_nccl_ip , pd_prefill_nccl_port , prefill_node_max_kv_trans_num
413
+ ):
414
+ prefill_node_id , pd_prefill_nccl_ip , pd_prefill_nccl_port , prefill_node_max_kv_trans_num = list (
415
+ map (obtain , [prefill_node_id , pd_prefill_nccl_ip , pd_prefill_nccl_port , prefill_node_max_kv_trans_num ])
406
416
)
407
417
thread_local_data .prefill_node_id = prefill_node_id
408
418
409
- logger .info (f"build trans infos { prefill_node_id } { prefill_ip } { prefill_port } " )
419
+ logger .info (f"build trans infos { prefill_node_id } { pd_prefill_nccl_ip } { pd_prefill_nccl_port } " )
410
420
# 如果有历史残留,一并移除
411
421
self .remove_trans_obj (prefill_node_id )
412
422
tran_obj = TransProcessObj ()
413
- tran_obj .create (prefill_node_id , prefill_ip , prefill_port , self )
423
+ tran_obj .create (prefill_node_id , pd_prefill_nccl_ip , pd_prefill_nccl_port , self )
414
424
self .node_id_to_trans_obj [prefill_node_id ] = tran_obj
415
425
return min (prefill_node_max_kv_trans_num , self .args .max_total_token_num )
416
426
@@ -476,7 +486,7 @@ def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optiona
476
486
477
487
def get_next_device_index (self ):
478
488
counts = [
479
- 0 if self .kv_trans_process_alive [ device_id ] else (1 << 20 ) for device_id in range (self .node_world_size )
489
+ 0 if self .is_kv_trans_process_alive ( device_id ) else (1 << 20 ) for device_id in range (self .node_world_size )
480
490
]
481
491
for obj in self .node_id_to_trans_obj .values ():
482
492
counts [obj .device_index ] += 1
@@ -509,16 +519,60 @@ def remove_trans_obj(self, prefill_node_id):
509
519
trans_obj .set_has_error ()
510
520
return
511
521
522
+ def remove_trans_obj_by_deviceid (self , device_id ):
523
+ for node_id , t_obj in self .node_id_to_trans_obj .items ():
524
+ if t_obj .device_index == device_id :
525
+ self .remove_dead_trans_obj (node_id )
526
+
527
+ def start_trans_process (self , device_id : int ):
528
+ task_in_queue = self .kv_trans_task_in_queues [device_id ]
529
+ task_out_queue = self .kv_trans_task_out_queues [device_id ]
530
+ self .kv_trans_process_restart_cnt [device_id ] += 1
531
+
532
+ if self .kv_trans_processes [device_id ]:
533
+ # force kill
534
+ try :
535
+ self .remove_trans_obj_by_deviceid (device_id )
536
+ process = psutil .Process (self .kv_trans_processes [device_id ].pid )
537
+ process .kill ()
538
+ self .kv_trans_processes [device_id ] = None
539
+ except Exception :
540
+ pass
541
+
542
+ try :
543
+ from .decode_trans_process import start_decode_trans_process
544
+
545
+ kv_trans_process = start_decode_trans_process (
546
+ self .args ,
547
+ device_id ,
548
+ task_in_queue ,
549
+ task_out_queue ,
550
+ self .mem_queues ,
551
+ )
552
+ assert task_out_queue .get (timeout = 30 ) == "proc_start"
553
+ self ._put_mem_manager_to_mem_queue ()
554
+ assert task_out_queue .get (timeout = 60 ) == "get_mem_managers_ok"
555
+
556
+ self .kv_trans_processes [device_id ] = kv_trans_process
557
+
558
+ return True
559
+ except Exception as e :
560
+ logger .warning (f"Failed start kv trans process for device { device_id } : { e } " )
561
+ return False
562
+
563
+ def is_kv_trans_process_alive (self , device_id ):
564
+ return self .kv_trans_process_restart_cnt [device_id ] <= KV_MOVE_MAX_RESTART_CNT
565
+
512
566
def check_trans_process (self , raise_exception = True ):
513
567
at_least_one_alive = False
514
568
for device_id in range (self .node_world_size ):
515
- if not self .kv_trans_process_alive [ device_id ] :
569
+ if not self .is_kv_trans_process_alive ( device_id ) :
516
570
continue
517
571
518
572
process = psutil .Process (self .kv_trans_processes [device_id ].pid )
519
573
if not (process .is_running () and process .status () != psutil .STATUS_ZOMBIE ):
520
- self . kv_trans_process_alive [ device_id ] = False
521
- logger . error ( f"kv trans process for device: { device_id } dead!!!" )
574
+ logger . error ( f"kv trans process for device: { device_id } dead!!!, try start again..." )
575
+ self . start_trans_process ( device_id )
522
576
else :
523
577
at_least_one_alive = True
524
578
@@ -530,17 +584,24 @@ def check_trans_process(self, raise_exception=True):
530
584
531
585
def timer_loop (self ):
532
586
try :
533
- last_check_time = time .time ()
534
587
while True :
535
588
self ._unfrozen_time_out_reqs_tokens ()
536
589
time .sleep (3.5 )
537
- if last_check_time - time .time () > 10.0 :
538
- self .check_trans_process ()
539
- last_check_time = time .time ()
540
590
except (BaseException , RuntimeError ) as e :
541
591
logger .exception (str (e ))
542
592
raise e
543
593
594
+ def check_trans_process_loop (self ):
595
+ try :
596
+ while True :
597
+ self .check_trans_process ()
598
+ time .sleep (10.0 )
599
+ except (BaseException , RuntimeError ) as e :
600
+ logger .exception (str (e ))
601
+ # kill parent process if any exception occurred
602
+ os .kill (os .getppid (), signal .SIGTERM )
603
+ raise e
604
+
544
605
545
606
def _init_env (args , info_queue : mp .Queue , mem_queues : List [mp .Queue ], event : mp .Event ):
546
607
import lightllm .utils .rpyc_fix_utils as _
@@ -552,6 +613,9 @@ def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.
552
613
t = ThreadedServer (manager , port = args .pd_decode_rpyc_port , protocol_config = {"allow_pickle" : True })
553
614
threading .Thread (target = lambda : t .start (), daemon = True ).start ()
554
615
616
+ kv_trans_process_check = threading .Thread (target = manager .check_trans_process_loop , daemon = True )
617
+ kv_trans_process_check .start ()
618
+
555
619
event .set ()
556
620
manager .timer_loop ()
557
621
return
0 commit comments