|
2 | 2 | import os
|
3 | 3 | import time
|
4 | 4 | from functools import partial
|
5 |
| -from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, |
6 |
| - Union, AsyncIterator, Callable) |
| 5 | +from typing import (Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, |
| 6 | + Union, AsyncIterator) |
7 | 7 |
|
8 | 8 | from transformers import PreTrainedTokenizer
|
9 | 9 |
|
10 | 10 | from vllm.lora.request import LoRARequest
|
11 | 11 | from vllm.config import ModelConfig
|
12 | 12 | from vllm.engine.arg_utils import AsyncEngineArgs
|
13 | 13 | from vllm.engine.llm_engine import LLMEngine
|
14 |
| -from vllm.engine.ray_utils import initialize_cluster, ray |
| 14 | +from vllm.engine.ray_utils import initialize_ray_cluster, ray |
15 | 15 | from vllm.logger import init_logger
|
16 | 16 | from vllm.outputs import RequestOutput
|
17 | 17 | from vllm.sampling_params import SamplingParams
|
@@ -208,17 +208,10 @@ async def step_async(self) -> List[RequestOutput]:
|
208 | 208 |
|
209 | 209 | if not scheduler_outputs.is_empty():
|
210 | 210 | # Execute the model.
|
211 |
| - all_outputs = await self._run_workers_async( |
212 |
| - "execute_model", |
213 |
| - driver_kwargs={ |
214 |
| - "seq_group_metadata_list": seq_group_metadata_list, |
215 |
| - "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, |
216 |
| - "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, |
217 |
| - "blocks_to_copy": scheduler_outputs.blocks_to_copy, |
218 |
| - }) |
219 |
| - |
220 |
| - # Only the driver worker returns the sampling results. |
221 |
| - output = all_outputs[0] |
| 211 | + output = await self.model_executor.execute_model_async( |
| 212 | + seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, |
| 213 | + scheduler_outputs.blocks_to_swap_out, |
| 214 | + scheduler_outputs.blocks_to_copy) |
222 | 215 | else:
|
223 | 216 | output = []
|
224 | 217 |
|
@@ -268,37 +261,8 @@ async def add_request_async(
|
268 | 261 | lora_request=lora_request,
|
269 | 262 | )
|
270 | 263 |
|
271 |
| - async def _run_workers_async( |
272 |
| - self, |
273 |
| - method: str, |
274 |
| - *args, |
275 |
| - driver_args: Optional[List[Any]] = None, |
276 |
| - driver_kwargs: Optional[Dict[str, Any]] = None, |
277 |
| - **kwargs, |
278 |
| - ) -> Any: |
279 |
| - """Runs the given method on all workers.""" |
280 |
| - coros = [] |
281 |
| - |
282 |
| - if driver_args is None: |
283 |
| - driver_args = args |
284 |
| - if driver_kwargs is None: |
285 |
| - driver_kwargs = kwargs |
286 |
| - |
287 |
| - # Run the driver worker asynchronously. |
288 |
| - driver_executor = getattr(self.driver_worker, method) |
289 |
| - coros.append(asyncio.get_event_loop().run_in_executor( |
290 |
| - None, partial(driver_executor, *driver_args, **driver_kwargs))) |
291 |
| - |
292 |
| - # Run the ray workers asynchronously. |
293 |
| - for worker in self.workers: |
294 |
| - coros.append(worker.execute_method.remote(method, *args, **kwargs)) |
295 |
| - |
296 |
| - all_outputs = await asyncio.gather(*coros) |
297 |
| - return all_outputs |
298 |
| - |
299 |
| - async def check_health_async(self): |
300 |
| - """Raises an error if engine is unhealthy.""" |
301 |
| - self._check_if_any_actor_is_dead() |
| 264 | + async def check_health_async(self) -> None: |
| 265 | + self.model_executor.check_health() |
302 | 266 |
|
303 | 267 |
|
304 | 268 | class AsyncLLMEngine:
|
@@ -353,6 +317,34 @@ def __init__(self,
|
353 | 317 | self._request_tracker: Optional[RequestTracker] = None
|
354 | 318 | self._errored_with: Optional[BaseException] = None
|
355 | 319 |
|
| 320 | + @classmethod |
| 321 | + def from_engine_args(cls, |
| 322 | + engine_args: AsyncEngineArgs, |
| 323 | + start_engine_loop: bool = True) -> "AsyncLLMEngine": |
| 324 | + """Creates an async LLM engine from the engine arguments.""" |
| 325 | + # Create the engine configs. |
| 326 | + engine_configs = engine_args.create_engine_configs() |
| 327 | + parallel_config = engine_configs[2] |
| 328 | + if parallel_config.worker_use_ray or engine_args.engine_use_ray: |
| 329 | + initialize_ray_cluster(parallel_config) |
| 330 | + from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync |
| 331 | + executor_class = RayGPUExecutorAsync |
| 332 | + else: |
| 333 | + assert parallel_config.world_size == 1, ( |
| 334 | + "Ray is required if parallel_config.world_size > 1.") |
| 335 | + from vllm.executor.gpu_executor import GPUExecutorAsync |
| 336 | + executor_class = GPUExecutorAsync |
| 337 | + # Create the async LLM engine. |
| 338 | + engine = cls(parallel_config.worker_use_ray, |
| 339 | + engine_args.engine_use_ray, |
| 340 | + *engine_configs, |
| 341 | + executor_class, |
| 342 | + log_requests=not engine_args.disable_log_requests, |
| 343 | + log_stats=not engine_args.disable_log_stats, |
| 344 | + max_log_len=engine_args.max_log_len, |
| 345 | + start_engine_loop=start_engine_loop) |
| 346 | + return engine |
| 347 | + |
356 | 348 | @property
|
357 | 349 | def is_running(self) -> bool:
|
358 | 350 | return (self.background_loop is not None
|
@@ -670,35 +662,13 @@ async def get_model_config(self) -> ModelConfig:
|
670 | 662 | else:
|
671 | 663 | return self.engine.get_model_config()
|
672 | 664 |
|
673 |
| - @classmethod |
674 |
| - def from_engine_args(cls, |
675 |
| - engine_args: AsyncEngineArgs, |
676 |
| - start_engine_loop: bool = True) -> "AsyncLLMEngine": |
677 |
| - """Creates an async LLM engine from the engine arguments.""" |
678 |
| - # Create the engine configs. |
679 |
| - engine_configs = engine_args.create_engine_configs() |
680 |
| - parallel_config = engine_configs[2] |
681 |
| - # Initialize the cluster. |
682 |
| - placement_group = initialize_cluster(parallel_config, |
683 |
| - engine_args.engine_use_ray) |
684 |
| - # Create the async LLM engine. |
685 |
| - engine = cls(parallel_config.worker_use_ray, |
686 |
| - engine_args.engine_use_ray, |
687 |
| - *engine_configs, |
688 |
| - placement_group, |
689 |
| - log_requests=not engine_args.disable_log_requests, |
690 |
| - log_stats=not engine_args.disable_log_stats, |
691 |
| - max_log_len=engine_args.max_log_len, |
692 |
| - start_engine_loop=start_engine_loop) |
693 |
| - return engine |
694 |
| - |
695 | 665 | async def do_log_stats(self) -> None:
|
696 | 666 | if self.engine_use_ray:
|
697 | 667 | await self.engine.do_log_stats.remote()
|
698 | 668 | else:
|
699 | 669 | self.engine.do_log_stats()
|
700 | 670 |
|
701 |
| - async def check_health(self): |
| 671 | + async def check_health(self) -> None: |
702 | 672 | """Raises an error if engine is unhealthy."""
|
703 | 673 | t = time.perf_counter()
|
704 | 674 | logger.debug("Starting health check...")
|
|
0 commit comments