Skip to content

Commit eb6d3c2

Browse files
authored
[Core] Eliminate parallel worker per-step task scheduling overhead (#4894)
1 parent 97b0300 commit eb6d3c2

12 files changed

+348
-209
lines changed

vllm/engine/async_llm_engine.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,14 @@ async def step_async(
234234
# Log stats.
235235
self.do_log_stats(scheduler_outputs, output)
236236

237+
if not request_outputs:
238+
# Stop the execute model loop in parallel workers until there are
239+
# more requests to process. This avoids waiting indefinitely in
240+
# torch.distributed ops which may otherwise time out, and unblocks
241+
# the RPC thread in the workers so that they can process any other
242+
# queued control plane messages, such as add/remove lora adapters.
243+
await self.model_executor.stop_remote_worker_execution_loop_async()
244+
237245
return request_outputs
238246

239247
async def encode_request_async(
@@ -687,7 +695,7 @@ async def encode(
687695
multi_modal_data: Multi modal data per request.
688696
689697
Yields:
690-
The output `EmbeddingRequestOutput` objects from the LLMEngine
698+
The output `EmbeddingRequestOutput` objects from the LLMEngine
691699
for the request.
692700
693701
Details:

vllm/engine/llm_engine.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
692692
# Log stats.
693693
self.do_log_stats(scheduler_outputs, output)
694694

695+
if not request_outputs:
696+
# Stop the execute model loop in parallel workers until there are
697+
# more requests to process. This avoids waiting indefinitely in
698+
# torch.distributed ops which may otherwise time out, and unblocks
699+
# the RPC thread in the workers so that they can process any other
700+
# queued control plane messages, such as add/remove lora adapters.
701+
self.model_executor.stop_remote_worker_execution_loop()
702+
695703
return request_outputs
696704

697705
def do_log_stats(

vllm/executor/distributed_gpu_executor.py

Lines changed: 97 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,29 @@
1+
import asyncio
12
from abc import abstractmethod
2-
from typing import Any, Dict, List, Optional, Set, Tuple
3+
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
34

45
from vllm.executor.executor_base import ExecutorAsyncBase
56
from vllm.executor.gpu_executor import GPUExecutor
67
from vllm.logger import init_logger
78
from vllm.lora.request import LoRARequest
8-
from vllm.sequence import SamplerOutput
9+
from vllm.sequence import ExecuteModelRequest, SamplerOutput
910

1011
logger = init_logger(__name__)
1112

1213

1314
class DistributedGPUExecutor(GPUExecutor):
1415
"""Abstract superclass of multi-GPU executor implementations."""
1516

17+
def __init__(self, *args, **kwargs):
18+
# This is non-None when the execute model loop is running
19+
# in the parallel workers (an asyncio.Task in the AsyncLLMEngine case).
20+
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
21+
# Updated by implementations that require additional args to be passed
22+
# to the _run_workers execute_model call
23+
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
24+
25+
super().__init__(*args, **kwargs)
26+
1627
def determine_num_available_blocks(self) -> Tuple[int, int]:
1728
"""Determine the number of available KV blocks.
1829
@@ -52,13 +63,28 @@ def initialize_cache(self, num_gpu_blocks: int,
5263
num_gpu_blocks=num_gpu_blocks,
5364
num_cpu_blocks=num_cpu_blocks)
5465

55-
def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
56-
all_outputs = self._run_workers("execute_model",
57-
driver_args=args,
58-
driver_kwargs=kwargs)
66+
def execute_model(
67+
self,
68+
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
69+
if self.parallel_worker_tasks is None:
70+
self.parallel_worker_tasks = self._run_workers(
71+
"start_worker_execution_loop",
72+
async_run_remote_workers_only=True,
73+
**self.extra_execute_model_run_workers_kwargs)
5974

6075
# Only the driver worker returns the sampling results.
61-
return all_outputs[0]
76+
return self._driver_execute_model(execute_model_req)
77+
78+
def stop_remote_worker_execution_loop(self) -> None:
79+
if self.parallel_worker_tasks is None:
80+
return
81+
82+
self._driver_execute_model()
83+
parallel_worker_tasks = self.parallel_worker_tasks
84+
self.parallel_worker_tasks = None
85+
# Ensure that workers exit model loop cleanly
86+
# (this will raise otherwise)
87+
self._wait_for_tasks_completion(parallel_worker_tasks)
6288

6389
def add_lora(self, lora_request: LoRARequest) -> bool:
6490
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
@@ -88,39 +114,84 @@ def save_sharded_state(
88114
pattern=pattern,
89115
max_size=max_size)
90116

117+
@abstractmethod
118+
def _driver_execute_model(
119+
self,
120+
execute_model_req: Optional[ExecuteModelRequest] = None
121+
) -> List[SamplerOutput]:
122+
"""Run execute_model in the driver worker.
123+
124+
Passing None will cause the driver to stop the model execution
125+
loop running in each of the remote workers.
126+
"""
127+
raise NotImplementedError
128+
91129
@abstractmethod
92130
def _run_workers(
93131
self,
94132
method: str,
95133
*args,
96-
driver_args: Optional[Tuple[Any, ...]] = None,
97-
driver_kwargs: Optional[Dict[str, Any]] = None,
134+
async_run_remote_workers_only: bool = False,
98135
max_concurrent_workers: Optional[int] = None,
99136
**kwargs,
100137
) -> Any:
101-
"""Runs the given method on all workers."""
138+
"""Runs the given method on all workers.
139+
140+
Args:
141+
async_run_remote_workers_only: If True the method will be run only
142+
in the remote workers, not the driver worker. It will also be
143+
run asynchronously and return a list of futures rather than
144+
blocking on the results.
145+
"""
146+
raise NotImplementedError
147+
148+
@abstractmethod
149+
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
150+
"""Wait for futures returned from _run_workers() with
151+
async_run_remote_workers_only to complete."""
102152
raise NotImplementedError
103153

104154

105155
class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
106156

157+
async def execute_model_async(
158+
self,
159+
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
160+
if self.parallel_worker_tasks is None:
161+
# Start model execution loop running in the parallel workers
162+
self.parallel_worker_tasks = asyncio.create_task(
163+
self._start_worker_execution_loop())
164+
165+
# Only the driver worker returns the sampling results.
166+
return await self._driver_execute_model_async(execute_model_req)
167+
168+
async def stop_remote_worker_execution_loop_async(self) -> None:
169+
if self.parallel_worker_tasks is None:
170+
return
171+
172+
await self._driver_execute_model_async()
173+
parallel_worker_tasks = self.parallel_worker_tasks
174+
self.parallel_worker_tasks = None
175+
# Ensure that workers exit model loop cleanly
176+
# (this will raise otherwise)
177+
await parallel_worker_tasks
178+
107179
@abstractmethod
108-
async def _run_workers_async(
180+
async def _driver_execute_model_async(
109181
self,
110-
method: str,
111-
*args,
112-
driver_args: Optional[Tuple[Any, ...]] = None,
113-
driver_kwargs: Optional[Dict[str, Any]] = None,
114-
**kwargs,
115-
) -> Any:
116-
"""Runs the given method on all workers."""
117-
raise NotImplementedError
182+
execute_model_req: Optional[ExecuteModelRequest] = None
183+
) -> List[SamplerOutput]:
184+
"""Execute the model asynchronously in the driver worker.
118185
119-
async def execute_model_async(self, *args,
120-
**kwargs) -> List[SamplerOutput]:
121-
all_outputs = await self._run_workers_async("execute_model",
122-
driver_args=args,
123-
driver_kwargs=kwargs)
186+
Passing None will cause the driver to stop the model execution
187+
loop running in each of the remote workers.
188+
"""
189+
raise NotImplementedError
124190

125-
# Only the driver worker returns the sampling results.
126-
return all_outputs[0]
191+
@abstractmethod
192+
async def _start_worker_execution_loop(self):
193+
"""Run execution loop on all workers. It guarantees all workers run
194+
the loop or none of them is running the loop. Loop can be stopped by
195+
`stop_remote_worker_execution_loop_async`.
196+
The API is idempotent (it guarantees only one loop runs at any moment)."""
197+
raise NotImplementedError

vllm/executor/executor_base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ def execute_model(
7474
"""Executes at least one model step on the given sequences."""
7575
raise NotImplementedError
7676

77+
def stop_remote_worker_execution_loop(self) -> None:
78+
"""Releases parallel workers from model loop."""
79+
return
80+
7781
@abstractmethod
7882
def add_lora(self, lora_request: LoRARequest) -> bool:
7983
raise NotImplementedError
@@ -109,6 +113,10 @@ async def execute_model_async(
109113
"""Executes one model step on the given sequences."""
110114
raise NotImplementedError
111115

116+
async def stop_remote_worker_execution_loop_async(self) -> None:
117+
"""Releases parallel workers from model loop."""
118+
return
119+
112120
async def check_health_async(self) -> None:
113121
"""Checks if the executor is healthy. If not, it should raise an
114122
exception."""

vllm/executor/multiproc_gpu_executor.py

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import asyncio
22
import os
33
from functools import partial
4-
from typing import Any, Dict, Optional, Tuple
4+
from typing import Any, List, Optional
55

66
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
77
DistributedGPUExecutor, DistributedGPUExecutorAsync)
88
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
99
ResultHandler, WorkerMonitor)
1010
from vllm.logger import init_logger
11+
from vllm.sequence import ExecuteModelRequest, SamplerOutput
1112
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
1213
get_vllm_instance_id, make_async)
1314

@@ -71,16 +72,34 @@ def shutdown(self):
7172
None)) is not None:
7273
worker_monitor.close()
7374

75+
def _driver_execute_model(
76+
self,
77+
execute_model_req: Optional[ExecuteModelRequest] = None
78+
) -> List[SamplerOutput]:
79+
"""Run execute_model in the driver worker.
80+
81+
Passing None will cause the driver to stop the model execution
82+
loop running in each of the remote workers.
83+
"""
84+
return self.driver_worker.execute_model(
85+
execute_model_req=execute_model_req)
86+
7487
def _run_workers(
7588
self,
7689
method: str,
7790
*args,
78-
driver_args: Optional[Tuple[Any, ...]] = None,
79-
driver_kwargs: Optional[Dict[str, Any]] = None,
91+
async_run_remote_workers_only: bool = False,
8092
max_concurrent_workers: Optional[int] = None,
8193
**kwargs,
8294
) -> Any:
83-
"""Runs the given method on all workers."""
95+
"""Runs the given method on all workers.
96+
97+
Args:
98+
async_run_remote_workers_only: If True the method will be run only
99+
in the remote workers, not the driver worker. It will also be
100+
run asynchronously and return a list of futures rather than
101+
blocking on the results.
102+
"""
84103

85104
if max_concurrent_workers:
86105
raise NotImplementedError(
@@ -92,15 +111,12 @@ def _run_workers(
92111
for worker in self.workers
93112
]
94113

95-
if driver_args is None:
96-
driver_args = args
97-
if driver_kwargs is None:
98-
driver_kwargs = kwargs
114+
if async_run_remote_workers_only:
115+
# Just return futures
116+
return worker_outputs
99117

100-
# Start the driver worker after all the ray workers.
101118
driver_worker_method = getattr(self.driver_worker, method)
102-
driver_worker_output = driver_worker_method(*driver_args,
103-
**driver_kwargs)
119+
driver_worker_output = driver_worker_method(*args, **kwargs)
104120

105121
# Get the results of the workers.
106122
return [driver_worker_output
@@ -111,30 +127,29 @@ def check_health(self) -> None:
111127
if not self.worker_monitor.is_alive():
112128
raise RuntimeError("Worker processes are not running")
113129

130+
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
131+
"""Wait for futures returned from _run_workers() with
132+
async_run_remote_workers_only to complete."""
133+
for result in parallel_worker_tasks:
134+
result.get()
135+
114136

115137
class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
116138
DistributedGPUExecutorAsync):
117139

118-
async def _run_workers_async(
119-
self,
120-
method: str,
121-
*args,
122-
driver_args: Optional[Tuple[Any, ...]] = None,
123-
driver_kwargs: Optional[Dict[str, Any]] = None,
124-
**kwargs,
125-
) -> Any:
126-
"""Runs the given method on all workers."""
127-
if driver_args is None:
128-
driver_args = args
129-
if driver_kwargs is None:
130-
driver_kwargs = kwargs
140+
def __init__(self, *args, **kwargs):
141+
super().__init__(*args, **kwargs)
142+
self.driver_exec_model = make_async(self.driver_worker.execute_model)
131143

132-
driver_executor = make_async(getattr(self.driver_worker, method))
144+
async def _driver_execute_model_async(
145+
self,
146+
execute_model_req: Optional[ExecuteModelRequest] = None
147+
) -> List[SamplerOutput]:
148+
return await self.driver_exec_model(execute_model_req)
133149

134-
# Run all the workers asynchronously.
135-
coros = [driver_executor(*driver_args, **driver_kwargs)] + [
136-
worker.execute_method_async(method, *args, **kwargs)
150+
async def _start_worker_execution_loop(self):
151+
coros = [
152+
worker.execute_method_async("start_worker_execution_loop")
137153
for worker in self.workers
138154
]
139-
140155
return await asyncio.gather(*coros)

0 commit comments

Comments
 (0)