@@ -39,7 +39,7 @@ class TransProcessObj:
39
39
rpyc_conn : object = None # rpyc_con 的连接对象
40
40
task_in_queue : mp .Queue = None
41
41
task_out_queue : mp .Queue = None
42
- device_index : str = None # 使用的gpu序号
42
+ device_index : int = None # 使用的gpu序号
43
43
manager : "PrefillKVMoveManager" = None
44
44
has_error : bool = False
45
45
request_kv_trans_task_queue : TaskQueue = None
@@ -57,15 +57,15 @@ def create(
57
57
58
58
device_index = manager .get_next_device_index () # 分配 trans 进程使用的显卡
59
59
prefill_node_id = manager .args .pd_node_id
60
- task_in_queue = manager .kv_trans_task_in_queue
61
- task_out_queue = manager .kv_trans_task_out_queue
60
+ task_in_queue = manager .kv_trans_task_in_queues [ device_index ]
61
+ task_out_queue = manager .kv_trans_task_out_queues [ device_index ]
62
62
63
63
task_in_queue .put (
64
64
PDTransJoinInfo (
65
65
prefill_id = prefill_node_id ,
66
66
prefill_device_id = device_index ,
67
67
prefill_ip = manager .host_ip ,
68
- prefill_port = manager .kv_trans_port ,
68
+ prefill_port = manager .kv_trans_ports [ device_index ] ,
69
69
decode_id = decode_node_id ,
70
70
decode_device_id = - 1 ,
71
71
)
@@ -74,7 +74,7 @@ def create(
74
74
# 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程
75
75
max_kv_trans_token_num = obtain (
76
76
con .root .build_trans_process (
77
- prefill_node_id , manager .host_ip , manager .kv_trans_port , manager .args .max_total_token_num
77
+ prefill_node_id , manager .host_ip , manager .kv_trans_ports [ device_index ] , manager .args .max_total_token_num
78
78
)
79
79
)
80
80
self .max_kv_trans_token_num = max_kv_trans_token_num
@@ -237,7 +237,6 @@ def kv_trans_handle_loop(self):
237
237
self .manager .put_to_release_task_queue (move_tasks )
238
238
239
239
logger .error (f"trans kv thread, decode id { self .decode_node_id } device_index { self .device_index } thread quit" )
240
- self .task_in_queue .put (PDTransLeaveInfo (decode_id = self .decode_node_id , prefill_id = self .prefill_node_id ))
241
240
return
242
241
243
242
def wait_thread_quit (self ):
@@ -282,6 +281,7 @@ def __del__(self):
282
281
try :
283
282
self .set_has_error ()
284
283
self .wait_thread_quit ()
284
+ self .task_in_queue .put (PDTransLeaveInfo (decode_id = self .decode_node_id , prefill_id = self .prefill_node_id ))
285
285
if self .request_kv_trans_task_queue is not None :
286
286
self .request_kv_trans_task_queue .clear_tasks ()
287
287
if self .ready_kv_trans_task_queue is not None :
@@ -329,24 +329,37 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
329
329
self .release_tasks_thread .start ()
330
330
331
331
# start one kv trans process per device (one for each index in range(self.node_world_size))
332
- self .kv_trans_task_in_queue = mp .Queue ()
333
- self .kv_trans_task_out_queue = mp .Queue ()
334
- from .prefill_trans_process import start_prefill_trans_process
335
-
336
- self .kv_trans_port = find_available_port (self .args .pd_p_allowed_port_min , self .args .pd_p_allowed_port_max )
337
- self .kv_trans_process = start_prefill_trans_process (
338
- self .args ,
339
- self .host_ip ,
340
- self .kv_trans_port ,
341
- self .kv_trans_task_in_queue ,
342
- self .kv_trans_task_out_queue ,
343
- self .mem_queues ,
344
- )
345
332
346
- assert self .kv_trans_task_out_queue .get (timeout = 30 ) == "proc_start"
347
- self ._put_mem_manager_to_mem_queue ()
348
- assert self .kv_trans_task_out_queue .get (timeout = 60 ) == "get_mem_managers_ok"
333
+ from .prefill_trans_process import start_prefill_trans_process
349
334
335
+ self .kv_trans_ports = []
336
+ self .kv_trans_processes = []
337
+ self .kv_trans_task_in_queues = []
338
+ self .kv_trans_task_out_queues = []
339
+ self .kv_trans_process_alive = []
340
+
341
+ for device_id in range (self .node_world_size ):
342
+ kv_trans_task_in_queue = mp .Queue ()
343
+ kv_trans_task_out_queue = mp .Queue ()
344
+ kv_trans_port = find_available_port (self .args .pd_p_allowed_port_min , self .args .pd_p_allowed_port_max )
345
+ kv_trans_process = start_prefill_trans_process (
346
+ self .args ,
347
+ self .host_ip ,
348
+ kv_trans_port ,
349
+ device_id ,
350
+ kv_trans_task_in_queue ,
351
+ kv_trans_task_out_queue ,
352
+ self .mem_queues ,
353
+ )
354
+ assert kv_trans_task_out_queue .get (timeout = 30 ) == "proc_start"
355
+ self ._put_mem_manager_to_mem_queue ()
356
+ assert kv_trans_task_out_queue .get (timeout = 60 ) == "get_mem_managers_ok"
357
+
358
+ self .kv_trans_ports .append (kv_trans_port )
359
+ self .kv_trans_processes .append (kv_trans_process )
360
+ self .kv_trans_task_in_queues .append (kv_trans_task_in_queue )
361
+ self .kv_trans_task_out_queues .append (kv_trans_task_out_queue )
362
+ self .kv_trans_process_alive .append (True )
350
363
return
351
364
352
365
def put_to_release_task_queue (self , task : Union [KVMoveTask , List [KVMoveTask ]]):
@@ -368,14 +381,28 @@ def handle_release_task_loop(self):
368
381
return
369
382
370
383
def check_trans_process(self, raise_exception=True):
    """Refresh the per-device liveness flags of the kv trans processes.

    Every device whose entry in ``self.kv_trans_process_alive`` is still
    True is probed through psutil. A process that is not running, is a
    zombie, or no longer exists at all is flagged as dead (its entry is set
    to False) and an error is logged. Devices already flagged as dead are
    skipped and never probed again.

    Args:
        raise_exception: when True (the default), raise if no trans process
            is left alive after the probe.

    Raises:
        Exception: if ``raise_exception`` is True and every trans process
            is dead.
    """
    at_least_one_alive = False
    for device_id in range(self.node_world_size):
        if not self.kv_trans_process_alive[device_id]:
            continue

        pid = self.kv_trans_processes[device_id].pid
        try:
            process = psutil.Process(pid)
            is_alive = process.is_running() and process.status() != psutil.STATUS_ZOMBIE
        except psutil.NoSuchProcess:
            # The pid is already gone (process exited and was reaped). That is
            # a dead trans process, not a failure of the health check itself.
            is_alive = False

        if is_alive:
            at_least_one_alive = True
        else:
            self.kv_trans_process_alive[device_id] = False
            logger.error(f"kv trans process for device: {device_id} dead!!!")

    if raise_exception and not at_least_one_alive:
        raise Exception("All trans process are dead!!!")

    return
376
401
377
402
def get_next_device_index (self ):
378
- counts = [0 for _ in range (self .node_world_size )]
403
+ counts = [
404
+ 0 if self .kv_trans_process_alive [device_id ] else (1 << 20 ) for device_id in range (self .node_world_size )
405
+ ]
379
406
for obj in self .node_id_to_trans_obj .values ():
380
407
counts [obj .device_index ] += 1
381
408
device_index = int (np .argmin (counts ))
0 commit comments