Skip to content

Commit 9c93e84

Browse files
author
Weichao Luo
committed
one kv trans process per tp.
1 parent 01af703 commit 9c93e84

File tree

5 files changed

+110
-123
lines changed

5 files changed

+110
-123
lines changed

Diff for: lightllm/distributed/pynccl.py

-71
Original file line numberDiff line numberDiff line change
@@ -248,51 +248,6 @@ def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, strea
248248
)
249249
return out_tensor
250250

251-
def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None):
252-
if self.disabled:
253-
return
254-
# nccl communicator created on a specific device
255-
# will only work on tensors on the same device
256-
# otherwise it will cause "illegal memory access"
257-
assert input_tensor.device == self.device, (
258-
f"this nccl communicator is created to work on {self.device}, "
259-
f"but the input tensor is on {input_tensor.device}"
260-
)
261-
if stream is None:
262-
stream = current_stream()
263-
self.nccl.ncclAllGather(
264-
buffer_type(input_tensor.data_ptr()),
265-
buffer_type(output_tensor.data_ptr()),
266-
input_tensor.numel(),
267-
ncclDataTypeEnum.from_torch(input_tensor.dtype),
268-
self.comm,
269-
cudaStream_t(stream.cuda_stream),
270-
)
271-
272-
def reduce_scatter(
273-
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
274-
):
275-
if self.disabled:
276-
return
277-
# nccl communicator created on a specific device
278-
# will only work on tensors on the same device
279-
# otherwise it will cause "illegal memory access"
280-
assert input_tensor.device == self.device, (
281-
f"this nccl communicator is created to work on {self.device}, "
282-
f"but the input tensor is on {input_tensor.device}"
283-
)
284-
if stream is None:
285-
stream = current_stream()
286-
self.nccl.ncclReduceScatter(
287-
buffer_type(input_tensor.data_ptr()),
288-
buffer_type(output_tensor.data_ptr()),
289-
output_tensor.numel(),
290-
ncclDataTypeEnum.from_torch(input_tensor.dtype),
291-
ncclRedOpTypeEnum.from_torch(op),
292-
self.comm,
293-
cudaStream_t(stream.cuda_stream),
294-
)
295-
296251
def send(self, tensor: torch.Tensor, dst: int, stream=None):
297252
if self.disabled:
298253
return
@@ -328,29 +283,3 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
328283
self.comm,
329284
cudaStream_t(stream.cuda_stream),
330285
)
331-
332-
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
333-
if self.disabled:
334-
return
335-
assert tensor.device == self.device, (
336-
f"this nccl communicator is created to work on {self.device}, "
337-
f"but the input tensor is on {tensor.device}"
338-
)
339-
if stream is None:
340-
stream = current_stream()
341-
if src == self.rank:
342-
sendbuff = buffer_type(tensor.data_ptr())
343-
# NCCL requires the sender also to have a receive buffer
344-
recvbuff = buffer_type(tensor.data_ptr())
345-
else:
346-
sendbuff = buffer_type()
347-
recvbuff = buffer_type(tensor.data_ptr())
348-
self.nccl.ncclBroadcast(
349-
sendbuff,
350-
recvbuff,
351-
tensor.numel(),
352-
ncclDataTypeEnum.from_torch(tensor.dtype),
353-
src,
354-
self.comm,
355-
cudaStream_t(stream.cuda_stream),
356-
)

Diff for: lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py

+43-16
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def create(self, prefill_node_id: str, prefill_ip: str, prefill_port: int, manag
5151

5252
device_index = manager.get_next_device_index()
5353
decode_node_id = manager.args.pd_node_id
54-
task_in_queue = manager.kv_trans_task_in_queue
55-
task_out_queue = manager.kv_trans_task_out_queue
54+
task_in_queue = manager.kv_trans_task_in_queues[device_index]
55+
task_out_queue = manager.kv_trans_task_out_queues[device_index]
5656

5757
task_in_queue.put(
5858
PDTransJoinInfo(
@@ -136,7 +136,6 @@ def kv_move_loop(self):
136136
self.manager.put_to_fail_release_task_queue(move_tasks)
137137

138138
logger.error(f"{func_name} prefill id {self.prefill_node_id} device_index {self.device_index} thread quit")
139-
self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id))
140139
return
141140

142141
def put_to_radix_loop(self):
@@ -217,6 +216,7 @@ def __del__(self):
217216
try:
218217
self.set_has_error()
219218
self.wait_thread_quit()
219+
self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id))
220220
if self.ready_to_move_queue is not None:
221221
self.ready_to_move_queue.clear_tasks()
222222
if self.move_finished_queue is not None:
@@ -266,18 +266,31 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
266266
# 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。
267267
self.device_locks = [threading.Lock() for _ in range(self.node_world_size)]
268268

269-
# start a single kv trans process
270-
self.kv_trans_task_in_queue = mp.Queue()
271-
self.kv_trans_task_out_queue = mp.Queue()
272269
from .decode_trans_process import start_decode_trans_process
273270

274-
self.kv_trans_process = start_decode_trans_process(
275-
self.args, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues
276-
)
271+
self.kv_trans_processes = []
272+
self.kv_trans_task_in_queues = []
273+
self.kv_trans_task_out_queues = []
274+
self.kv_trans_process_alive = []
275+
276+
for device_index in range(self.node_world_size):
277+
kv_trans_task_in_queue = mp.Queue()
278+
kv_trans_task_out_queue = mp.Queue()
279+
kv_trans_process = start_decode_trans_process(
280+
self.args,
281+
device_index,
282+
kv_trans_task_in_queue,
283+
kv_trans_task_out_queue,
284+
self.mem_queues,
285+
)
286+
assert kv_trans_task_out_queue.get(timeout=30) == "proc_start"
287+
self._put_mem_manager_to_mem_queue()
288+
assert kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok"
277289

278-
assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start"
279-
self._put_mem_manager_to_mem_queue()
280-
assert self.kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok"
290+
self.kv_trans_processes.append(kv_trans_process)
291+
self.kv_trans_task_in_queues.append(kv_trans_task_in_queue)
292+
self.kv_trans_task_out_queues.append(kv_trans_task_out_queue)
293+
self.kv_trans_process_alive.append(True)
281294

282295
return
283296

@@ -462,7 +475,9 @@ def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optiona
462475
return ans_list
463476

464477
def get_next_device_index(self):
465-
counts = [0 for _ in range(self.node_world_size)]
478+
counts = [
479+
0 if self.kv_trans_process_alive[device_id] else (1 << 20) for device_id in range(self.node_world_size)
480+
]
466481
for obj in self.node_id_to_trans_obj.values():
467482
counts[obj.device_index] += 1
468483
device_index = int(np.argmin(counts))
@@ -495,10 +510,22 @@ def remove_trans_obj(self, prefill_node_id):
495510
return
496511

497512
def check_trans_process(self, raise_exception=True):
498-
process = psutil.Process(self.kv_trans_process.pid)
499-
if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE):
513+
at_least_one_alive = False
514+
for device_id in range(self.node_world_size):
515+
if not self.kv_trans_process_alive[device_id]:
516+
continue
517+
518+
process = psutil.Process(self.kv_trans_processes[device_id].pid)
519+
if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE):
520+
self.kv_trans_process_alive[device_id] = False
521+
logger.error(f"kv trans process for device: {device_id} dead!!!")
522+
else:
523+
at_least_one_alive = True
524+
525+
if not at_least_one_alive:
500526
if raise_exception:
501-
raise Exception(f"trans process: {self.kv_trans_process.pid} is dead")
527+
raise Exception("All trans process are dead!!!")
528+
502529
return
503530

504531
def timer_loop(self):

Diff for: lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,17 @@ def _handle_prefill_join(
6666
logger.warning(f"error while connect to prefill node: {e}")
6767

6868

69-
def _init_env(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue]):
69+
def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue]):
7070

7171
dp_size_in_node = max(1, args.dp // args.nnodes)
72-
node_world_size = args.tp // args.nnodes
7372

7473
try:
74+
torch.cuda.set_device(device_id)
7575
graceful_registry(inspect.currentframe().f_code.co_name)
7676
task_out_queue.put("proc_start")
77+
7778
mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues]
78-
assert len(mem_managers) == node_world_size
79+
7980
task_out_queue.put("get_mem_managers_ok")
8081
prefill_to_comm: Dict[int, PyNcclCommunicator] = {}
8182
while True:
@@ -97,12 +98,13 @@ def _init_env(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queue
9798

9899
def start_decode_trans_process(
99100
args,
101+
device_id: int,
100102
task_in_queue: mp.Queue,
101103
task_out_queue: mp.Queue,
102104
mem_queues: List[mp.Queue],
103105
):
104-
proc = mp.Process(target=_init_env, args=(args, task_in_queue, task_out_queue, mem_queues))
106+
proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue, mem_queues))
105107
proc.start()
106108
assert proc.is_alive()
107-
logger.info("decode trans kv process start!")
109+
logger.info(f"decode trans kv process for device: {device_id} start!")
108110
return proc

Diff for: lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py

+53-26
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class TransProcessObj:
3939
rpyc_conn: object = None # rpyc_con 的连接对象
4040
task_in_queue: mp.Queue = None
4141
task_out_queue: mp.Queue = None
42-
device_index: str = None # 使用的gpu序号
42+
device_index: int = None # 使用的gpu序号
4343
manager: "PrefillKVMoveManager" = None
4444
has_error: bool = False
4545
request_kv_trans_task_queue: TaskQueue = None
@@ -57,15 +57,15 @@ def create(
5757

5858
device_index = manager.get_next_device_index() # 分配 trans 进程使用的显卡
5959
prefill_node_id = manager.args.pd_node_id
60-
task_in_queue = manager.kv_trans_task_in_queue
61-
task_out_queue = manager.kv_trans_task_out_queue
60+
task_in_queue = manager.kv_trans_task_in_queues[device_index]
61+
task_out_queue = manager.kv_trans_task_out_queues[device_index]
6262

6363
task_in_queue.put(
6464
PDTransJoinInfo(
6565
prefill_id=prefill_node_id,
6666
prefill_device_id=device_index,
6767
prefill_ip=manager.host_ip,
68-
prefill_port=manager.kv_trans_port,
68+
prefill_port=manager.kv_trans_ports[device_index],
6969
decode_id=decode_node_id,
7070
decode_device_id=-1,
7171
)
@@ -74,7 +74,7 @@ def create(
7474
# 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程
7575
max_kv_trans_token_num = obtain(
7676
con.root.build_trans_process(
77-
prefill_node_id, manager.host_ip, manager.kv_trans_port, manager.args.max_total_token_num
77+
prefill_node_id, manager.host_ip, manager.kv_trans_ports[device_index], manager.args.max_total_token_num
7878
)
7979
)
8080
self.max_kv_trans_token_num = max_kv_trans_token_num
@@ -237,7 +237,6 @@ def kv_trans_handle_loop(self):
237237
self.manager.put_to_release_task_queue(move_tasks)
238238

239239
logger.error(f"trans kv thread, decode id {self.decode_node_id} device_index {self.device_index} thread quit")
240-
self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id))
241240
return
242241

243242
def wait_thread_quit(self):
@@ -282,6 +281,7 @@ def __del__(self):
282281
try:
283282
self.set_has_error()
284283
self.wait_thread_quit()
284+
self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id))
285285
if self.request_kv_trans_task_queue is not None:
286286
self.request_kv_trans_task_queue.clear_tasks()
287287
if self.ready_kv_trans_task_queue is not None:
@@ -329,24 +329,37 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
329329
self.release_tasks_thread.start()
330330

331331
# start a single kv trans process
332-
self.kv_trans_task_in_queue = mp.Queue()
333-
self.kv_trans_task_out_queue = mp.Queue()
334-
from .prefill_trans_process import start_prefill_trans_process
335-
336-
self.kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max)
337-
self.kv_trans_process = start_prefill_trans_process(
338-
self.args,
339-
self.host_ip,
340-
self.kv_trans_port,
341-
self.kv_trans_task_in_queue,
342-
self.kv_trans_task_out_queue,
343-
self.mem_queues,
344-
)
345332

346-
assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start"
347-
self._put_mem_manager_to_mem_queue()
348-
assert self.kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok"
333+
from .prefill_trans_process import start_prefill_trans_process
349334

335+
self.kv_trans_ports = []
336+
self.kv_trans_processes = []
337+
self.kv_trans_task_in_queues = []
338+
self.kv_trans_task_out_queues = []
339+
self.kv_trans_process_alive = []
340+
341+
for device_id in range(self.node_world_size):
342+
kv_trans_task_in_queue = mp.Queue()
343+
kv_trans_task_out_queue = mp.Queue()
344+
kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max)
345+
kv_trans_process = start_prefill_trans_process(
346+
self.args,
347+
self.host_ip,
348+
kv_trans_port,
349+
device_id,
350+
kv_trans_task_in_queue,
351+
kv_trans_task_out_queue,
352+
self.mem_queues,
353+
)
354+
assert kv_trans_task_out_queue.get(timeout=30) == "proc_start"
355+
self._put_mem_manager_to_mem_queue()
356+
assert kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok"
357+
358+
self.kv_trans_ports.append(kv_trans_port)
359+
self.kv_trans_processes.append(kv_trans_process)
360+
self.kv_trans_task_in_queues.append(kv_trans_task_in_queue)
361+
self.kv_trans_task_out_queues.append(kv_trans_task_out_queue)
362+
self.kv_trans_process_alive.append(True)
350363
return
351364

352365
def put_to_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]):
@@ -368,14 +381,28 @@ def handle_release_task_loop(self):
368381
return
369382

370383
def check_trans_process(self, raise_exception=True):
371-
process = psutil.Process(self.kv_trans_process.pid)
372-
if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE):
384+
at_least_one_alive = False
385+
for device_id in range(self.node_world_size):
386+
if not self.kv_trans_process_alive[device_id]:
387+
continue
388+
389+
process = psutil.Process(self.kv_trans_processes[device_id].pid)
390+
if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE):
391+
self.kv_trans_process_alive[device_id] = False
392+
logger.error(f"kv trans process for device: {device_id} dead!!!")
393+
else:
394+
at_least_one_alive = True
395+
396+
if not at_least_one_alive:
373397
if raise_exception:
374-
raise Exception(f"trans process: {self.kv_trans_process.pid} is dead")
398+
raise Exception("All trans process are dead!!!")
399+
375400
return
376401

377402
def get_next_device_index(self):
378-
counts = [0 for _ in range(self.node_world_size)]
403+
counts = [
404+
0 if self.kv_trans_process_alive[device_id] else (1 << 20) for device_id in range(self.node_world_size)
405+
]
379406
for obj in self.node_id_to_trans_obj.values():
380407
counts[obj.device_index] += 1
381408
device_index = int(np.argmin(counts))

0 commit comments

Comments
 (0)