@@ -1,18 +1,29 @@
+import asyncio
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union

 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput

 logger = init_logger(__name__)


 class DistributedGPUExecutor(GPUExecutor):
     """Abstract superclass of multi-GPU executor implementations."""

+    def __init__(self, *args, **kwargs):
+        # This is non-None when the execute model loop is running
+        # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
+        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
+        # Updated by implementations that require additional args to be passed
+        # to the _run_workers execute_model call
+        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
+
+        super().__init__(*args, **kwargs)
+
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks.

@@ -52,13 +63,28 @@ def initialize_cache(self, num_gpu_blocks: int,
                           num_gpu_blocks=num_gpu_blocks,
                           num_cpu_blocks=num_cpu_blocks)

-    def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
-        all_outputs = self._run_workers("execute_model",
-                                        driver_args=args,
-                                        driver_kwargs=kwargs)
+    def execute_model(
+        self,
+        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        if self.parallel_worker_tasks is None:
+            self.parallel_worker_tasks = self._run_workers(
+                "start_worker_execution_loop",
+                async_run_remote_workers_only=True,
+                **self.extra_execute_model_run_workers_kwargs)

         # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+        return self._driver_execute_model(execute_model_req)
+
+    def stop_remote_worker_execution_loop(self) -> None:
+        if self.parallel_worker_tasks is None:
+            return
+
+        self._driver_execute_model()
+        parallel_worker_tasks = self.parallel_worker_tasks
+        self.parallel_worker_tasks = None
+        # Ensure that workers exit model loop cleanly
+        # (this will raise otherwise)
+        self._wait_for_tasks_completion(parallel_worker_tasks)

     def add_lora(self, lora_request: LoRARequest) -> bool:
         assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
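Note on the control flow in the hunk above: instead of broadcasting every execute_model call to all workers, the driver now starts a long-running start_worker_execution_loop on the remote workers once (returned as futures via async_run_remote_workers_only=True) and afterwards runs only the driver step per request; calling _driver_execute_model with no request signals the remote loops to exit. Below is a minimal, hypothetical sketch of that handshake using threads and in-process queues in place of the real multi-GPU workers and driver-to-worker broadcast; ToyDistributedExecutor and its queue plumbing are illustrative only, not vLLM's API.

import queue
import threading
from typing import Any, List, Optional


class ToyDistributedExecutor:

    def __init__(self, num_workers: int = 2):
        # Non-None while the remote workers are inside their execution loop.
        self.parallel_worker_tasks: Optional[List[threading.Thread]] = None
        # One queue per remote worker, standing in for the driver->worker
        # broadcast used by the real executor (an assumption of this sketch).
        self._queues = [queue.Queue() for _ in range(num_workers)]

    def _run_workers(self, method: str) -> List[threading.Thread]:
        # Start `method` on the remote workers only and return handles
        # instead of blocking, mirroring async_run_remote_workers_only=True.
        assert method == "start_worker_execution_loop"
        threads = [
            threading.Thread(target=self._worker_loop, args=(q, ))
            for q in self._queues
        ]
        for t in threads:
            t.start()
        return threads

    def _worker_loop(self, q: "queue.Queue[Any]") -> None:
        # Remote workers spin here until the driver broadcasts None.
        while True:
            req = q.get()
            if req is None:
                return  # clean exit of the execution loop
            # ... a real worker would run one model step for `req` here ...

    def _driver_execute_model(self, req: Any = None) -> Optional[str]:
        # The driver broadcasts the request (or the None stop sentinel) and
        # runs its own step; only the driver returns sampling results.
        for q in self._queues:
            q.put(req)
        return None if req is None else f"sampler_output({req})"

    def execute_model(self, req: Any) -> Optional[str]:
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop")
        return self._driver_execute_model(req)

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return
        self._driver_execute_model(None)  # broadcast the stop sentinel
        tasks, self.parallel_worker_tasks = self.parallel_worker_tasks, None
        for t in tasks:  # the _wait_for_tasks_completion equivalent
            t.join()


executor = ToyDistributedExecutor()
print(executor.execute_model("step-1"))  # first call starts the worker loops
print(executor.execute_model("step-2"))  # loops are already running
executor.stop_remote_worker_execution_loop()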
@@ -88,39 +114,84 @@ def save_sharded_state(
                           pattern=pattern,
                           max_size=max_size)

+    @abstractmethod
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
+
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
-        """Runs the given method on all workers."""
+        """Runs the given method on all workers.
+
+        Args:
+            async_run_remote_workers_only: If True, the method will be run
+                only in the remote workers, not the driver worker. It will
+                also be run asynchronously and return a list of futures
+                rather than blocking on the results.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
         raise NotImplementedError


 class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):

+    async def execute_model_async(
+        self,
+        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        if self.parallel_worker_tasks is None:
+            # Start model execution loop running in the parallel workers
+            self.parallel_worker_tasks = asyncio.create_task(
+                self._start_worker_execution_loop())
+
+        # Only the driver worker returns the sampling results.
+        return await self._driver_execute_model_async(execute_model_req)
+
+    async def stop_remote_worker_execution_loop_async(self) -> None:
+        if self.parallel_worker_tasks is None:
+            return
+
+        await self._driver_execute_model_async()
+        parallel_worker_tasks = self.parallel_worker_tasks
+        self.parallel_worker_tasks = None
+        # Ensure that workers exit model loop cleanly
+        # (this will raise otherwise)
+        await parallel_worker_tasks
+
     @abstractmethod
-    async def _run_workers_async(
+    async def _driver_execute_model_async(
         self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        raise NotImplementedError
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Execute the model asynchronously in the driver worker.

-    async def execute_model_async(self, *args,
-                                  **kwargs) -> List[SamplerOutput]:
-        all_outputs = await self._run_workers_async("execute_model",
-                                                    driver_args=args,
-                                                    driver_kwargs=kwargs)
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        raise NotImplementedError

-        # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+    @abstractmethod
+    async def _start_worker_execution_loop(self):
+        """Run the execution loop on all workers. It guarantees that all
+        workers run the loop or none of them do. The loop can be stopped
+        by `stop_remote_worker_execution_loop`. The API is idempotent
+        (guarantees at most one loop is running at any moment)."""
+        raise NotImplementedError