Skip to content

Commit aa3a044

Browse files
author
Weichao Luo
committed
fix.
1 parent 9c93e84 commit aa3a044

File tree

4 files changed

+221
-97
lines changed

4 files changed

+221
-97
lines changed

Diff for: lightllm/server/pd_io_struct.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ class PDTransJoinInfo:
8181
decode_device_id: int
8282
prefill_id: int
8383
prefill_device_id: int
84-
prefill_ip: str
85-
prefill_port: int
84+
pd_prefill_nccl_ip: str
85+
pd_prefill_nccl_port: int
8686

8787

8888
@dataclass

Diff for: lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py

+113-49
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,17 @@
2929
thread_local_data = threading.local()
3030

3131
KV_MOVE_MAX_NUM = 16
32+
KV_MOVE_MAX_RESTART_CNT = 3
3233

3334

3435
@dataclass
3536
class TransProcessObj:
3637
prefill_node_id: int = None
38+
process: mp.Process = None
3739
task_in_queue: mp.Queue = None
3840
task_out_queue: mp.Queue = None
39-
prefill_ip: str = None
40-
prefill_port: int = None
41+
pd_prefill_nccl_ip: str = None
42+
pd_prefill_nccl_port: int = None
4143
device_index: int = None
4244
manager: "DecodeKVMoveManager" = None
4345
has_error: bool = False
@@ -47,32 +49,36 @@ class TransProcessObj:
4749
put_to_radix_thread: threading.Thread = None
4850
latest_check_time: float = None
4951

50-
def create(self, prefill_node_id: str, prefill_ip: str, prefill_port: int, manager: "DecodeKVMoveManager"):
52+
def create(
53+
self, prefill_node_id: str, pd_prefill_nccl_ip: str, pd_prefill_nccl_port: int, manager: "DecodeKVMoveManager"
54+
):
5155

5256
device_index = manager.get_next_device_index()
5357
decode_node_id = manager.args.pd_node_id
5458
task_in_queue = manager.kv_trans_task_in_queues[device_index]
5559
task_out_queue = manager.kv_trans_task_out_queues[device_index]
5660

57-
task_in_queue.put(
58-
PDTransJoinInfo(
59-
prefill_id=prefill_node_id,
60-
prefill_device_id=-1,
61-
prefill_ip=prefill_ip,
62-
prefill_port=prefill_port,
63-
decode_id=decode_node_id,
64-
decode_device_id=device_index,
61+
with manager.device_locks[device_index]:
62+
task_in_queue.put(
63+
PDTransJoinInfo(
64+
prefill_id=prefill_node_id,
65+
prefill_device_id=-1,
66+
pd_prefill_nccl_ip=pd_prefill_nccl_ip,
67+
pd_prefill_nccl_port=pd_prefill_nccl_port,
68+
decode_id=decode_node_id,
69+
decode_device_id=device_index,
70+
)
6571
)
66-
)
67-
assert task_out_queue.get(timeout=60) == "nccl_ok"
72+
assert task_out_queue.get(timeout=60) == "nccl_ok"
6873

6974
self.prefill_node_id = prefill_node_id
7075
self.decode_node_id = decode_node_id
7176
self.task_in_queue = task_in_queue
7277
self.task_out_queue = task_out_queue
73-
self.prefill_ip = prefill_ip
74-
self.prefill_port = prefill_port
78+
self.pd_prefill_nccl_ip = pd_prefill_nccl_ip
79+
self.pd_prefill_nccl_port = pd_prefill_nccl_port
7580
self.device_index = device_index
81+
self.process = manager.kv_trans_processes[device_index]
7682

7783
self.manager = manager
7884
self.latest_check_time = time.time()
@@ -90,6 +96,20 @@ def create(self, prefill_node_id: str, prefill_ip: str, prefill_port: int, manag
9096
self.put_to_radix_thread.start()
9197
return
9298

99+
def check_trans_process(self, raise_exception=True):
100+
process = psutil.Process(self.process.pid)
101+
if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE):
102+
self.set_has_error()
103+
if raise_exception:
104+
raise Exception(f"trans process: {self.process.pid} is dead")
105+
return
106+
107+
def timer_to_check_status(self, raise_exception=True):
108+
if time.time() - self.latest_check_time >= 2.0:
109+
self.latest_check_time = time.time()
110+
self.check_trans_process(raise_exception=raise_exception)
111+
return
112+
93113
def _transfer_kv(self, move_tasks: List[KVMoveTask]):
94114
with self.manager.device_locks[self.device_index]:
95115
self.task_in_queue.put(move_tasks.copy(), timeout=10)
@@ -120,6 +140,7 @@ def kv_move_loop(self):
120140
logger.info(f"{func_name} get task {task.to_decode_log_info()}")
121141

122142
try:
143+
self.timer_to_check_status(raise_exception=True)
123144
if not kv_trans_use_p2p():
124145
with self.manager.kv_trans_lock:
125146
self._transfer_kv(move_tasks)
@@ -150,6 +171,7 @@ def put_to_radix_loop(self):
150171
logger.info(f"{func_name} get put radix task {task.to_decode_log_info()}")
151172

152173
try:
174+
self.timer_to_check_status(raise_exception=True)
153175
# random to check stats
154176
self.manager._put_kv_received_to_radix_cache(move_tasks.copy())
155177
for task in move_tasks.copy():
@@ -266,31 +288,17 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
266288
# 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。
267289
self.device_locks = [threading.Lock() for _ in range(self.node_world_size)]
268290

269-
from .decode_trans_process import start_decode_trans_process
270-
271291
self.kv_trans_processes = []
272292
self.kv_trans_task_in_queues = []
273293
self.kv_trans_task_out_queues = []
274-
self.kv_trans_process_alive = []
275-
276-
for device_index in range(self.node_world_size):
277-
kv_trans_task_in_queue = mp.Queue()
278-
kv_trans_task_out_queue = mp.Queue()
279-
kv_trans_process = start_decode_trans_process(
280-
self.args,
281-
device_index,
282-
kv_trans_task_in_queue,
283-
kv_trans_task_out_queue,
284-
self.mem_queues,
285-
)
286-
assert kv_trans_task_out_queue.get(timeout=30) == "proc_start"
287-
self._put_mem_manager_to_mem_queue()
288-
assert kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok"
294+
self.kv_trans_process_restart_cnt = []
289295

290-
self.kv_trans_processes.append(kv_trans_process)
291-
self.kv_trans_task_in_queues.append(kv_trans_task_in_queue)
292-
self.kv_trans_task_out_queues.append(kv_trans_task_out_queue)
293-
self.kv_trans_process_alive.append(True)
296+
for device_id in range(self.node_world_size):
297+
self.kv_trans_task_in_queues.append(mp.Queue())
298+
self.kv_trans_task_out_queues.append(mp.Queue())
299+
self.kv_trans_process_restart_cnt.append(0)
300+
self.kv_trans_processes.append(None)
301+
assert self.start_trans_process(device_id)
294302

295303
return
296304

@@ -400,17 +408,19 @@ def exposed_check_alive(self):
400408
# 用于 prefill node check 通信连接的状态。
401409
return
402410

403-
def exposed_build_trans_process(self, prefill_node_id, prefill_ip, prefill_port, prefill_node_max_kv_trans_num):
404-
prefill_node_id, prefill_ip, prefill_port, prefill_node_max_kv_trans_num = list(
405-
map(obtain, [prefill_node_id, prefill_ip, prefill_port, prefill_node_max_kv_trans_num])
411+
def exposed_build_trans_process(
412+
self, prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num
413+
):
414+
prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num = list(
415+
map(obtain, [prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num])
406416
)
407417
thread_local_data.prefill_node_id = prefill_node_id
408418

409-
logger.info(f"build trans infos {prefill_node_id} {prefill_ip} {prefill_port}")
419+
logger.info(f"build trans infos {prefill_node_id} {pd_prefill_nccl_ip} {pd_prefill_nccl_port}")
410420
# 如果有历史残留,一并移除
411421
self.remove_trans_obj(prefill_node_id)
412422
tran_obj = TransProcessObj()
413-
tran_obj.create(prefill_node_id, prefill_ip, prefill_port, self)
423+
tran_obj.create(prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, self)
414424
self.node_id_to_trans_obj[prefill_node_id] = tran_obj
415425
return min(prefill_node_max_kv_trans_num, self.args.max_total_token_num)
416426

@@ -476,7 +486,7 @@ def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optiona
476486

477487
def get_next_device_index(self):
478488
counts = [
479-
0 if self.kv_trans_process_alive[device_id] else (1 << 20) for device_id in range(self.node_world_size)
489+
0 if self.is_kv_trans_process_alive(device_id) else (1 << 20) for device_id in range(self.node_world_size)
480490
]
481491
for obj in self.node_id_to_trans_obj.values():
482492
counts[obj.device_index] += 1
@@ -509,16 +519,60 @@ def remove_trans_obj(self, prefill_node_id):
509519
trans_obj.set_has_error()
510520
return
511521

522+
def remove_trans_obj_by_deviceid(self, device_id):
523+
for node_id, t_obj in self.node_id_to_trans_obj.items():
524+
if t_obj.device_index == device_id:
525+
self.remove_dead_trans_obj(node_id)
526+
527+
def start_trans_process(self, device_id: int):
528+
task_in_queue = self.kv_trans_task_in_queues[device_id]
529+
task_out_queue = self.kv_trans_task_out_queues[device_id]
530+
self.kv_trans_process_restart_cnt[device_id] += 1
531+
532+
if self.kv_trans_processes[device_id]:
533+
# force kill
534+
try:
535+
self.remove_trans_obj_by_deviceid(device_id)
536+
process = psutil.Process(self.kv_trans_processes[device_id].pid)
537+
process.kill()
538+
self.kv_trans_processes[device_id] = None
539+
except Exception:
540+
pass
541+
542+
try:
543+
from .decode_trans_process import start_decode_trans_process
544+
545+
kv_trans_process = start_decode_trans_process(
546+
self.args,
547+
device_id,
548+
task_in_queue,
549+
task_out_queue,
550+
self.mem_queues,
551+
)
552+
assert task_out_queue.get(timeout=30) == "proc_start"
553+
self._put_mem_manager_to_mem_queue()
554+
assert task_out_queue.get(timeout=60) == "get_mem_managers_ok"
555+
556+
self.kv_trans_processes[device_id] = kv_trans_process
557+
558+
return True
559+
except Exception as e:
560+
logger.warning(f"Failed start kv trans process for device {device_id}: {e}")
561+
return False
562+
563+
def is_kv_trans_process_alive(self, device_id):
564+
return self.kv_trans_process_restart_cnt[device_id] <= KV_MOVE_MAX_RESTART_CNT
565+
512566
def check_trans_process(self, raise_exception=True):
513567
at_least_one_alive = False
514568
for device_id in range(self.node_world_size):
515-
if not self.kv_trans_process_alive[device_id]:
569+
if not self.is_kv_trans_process_alive(device_id):
516570
continue
517571

518572
process = psutil.Process(self.kv_trans_processes[device_id].pid)
519573
if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE):
520-
self.kv_trans_process_alive[device_id] = False
521-
logger.error(f"kv trans process for device: {device_id} dead!!!")
574+
logger.error(f"kv trans process for device: {device_id} dead!!!, try start again...")
575+
self.start_trans_process(device_id)
522576
else:
523577
at_least_one_alive = True
524578

@@ -530,17 +584,24 @@ def check_trans_process(self, raise_exception=True):
530584

531585
def timer_loop(self):
532586
try:
533-
last_check_time = time.time()
534587
while True:
535588
self._unfrozen_time_out_reqs_tokens()
536589
time.sleep(3.5)
537-
if last_check_time - time.time() > 10.0:
538-
self.check_trans_process()
539-
last_check_time = time.time()
540590
except (BaseException, RuntimeError) as e:
541591
logger.exception(str(e))
542592
raise e
543593

594+
def check_trans_process_loop(self):
595+
try:
596+
while True:
597+
self.check_trans_process()
598+
time.sleep(10.0)
599+
except (BaseException, RuntimeError) as e:
600+
logger.exception(str(e))
601+
# kill parent process if any exception occurred
602+
os.kill(os.getppid(), signal.SIGTERM)
603+
raise e
604+
544605

545606
def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event):
546607
import lightllm.utils.rpyc_fix_utils as _
@@ -552,6 +613,9 @@ def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.
552613
t = ThreadedServer(manager, port=args.pd_decode_rpyc_port, protocol_config={"allow_pickle": True})
553614
threading.Thread(target=lambda: t.start(), daemon=True).start()
554615

616+
kv_trans_process_check = threading.Thread(target=manager.check_trans_process_loop, daemon=True)
617+
kv_trans_process_check.start()
618+
555619
event.set()
556620
manager.timer_loop()
557621
return

Diff for: lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _handle_prefill_join(
5353
):
5454
try:
5555
store_client = TCPStore(
56-
host_name=node_info.prefill_ip, port=node_info.prefill_port, is_master=False, use_libuv=False
56+
host_name=node_info.pd_prefill_nccl_ip, port=node_info.pd_prefill_nccl_port, is_master=False, use_libuv=True
5757
)
5858
group = StatelessP2PProcessGroup.create(
5959
src_id=node_info.prefill_id, dest_id=node_info.decode_id, is_server=False, store=store_client

0 commit comments

Comments
 (0)