Commit 4c92270

Add distributed model executor abstraction (#3191)

1 parent: 657061f
13 files changed: +818 -509 lines
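The change replaces the engine's direct worker management with a model executor object: LLMEngine and AsyncLLMEngine now hold a model_executor and delegate model execution and health checks to it, with a single-GPU backend (GPUExecutorAsync) and a Ray-distributed backend (RayGPUExecutorAsync) chosen at construction time. The sketch below only illustrates the interface implied by the call sites in this diff; the name ExecutorBase and the exact signatures are assumptions, since the real executor classes under vllm/executor/ are not part of this excerpt.

from abc import ABC, abstractmethod
from typing import Dict, List

class ExecutorBase(ABC):  # hypothetical name; the real base class is not shown in this excerpt
    """Hides single-GPU vs. Ray-distributed execution behind one object owned by the engine."""

    @abstractmethod
    def execute_model(self, seq_group_metadata_list,
                      blocks_to_swap_in: Dict[int, int],
                      blocks_to_swap_out: Dict[int, int],
                      blocks_to_copy: Dict[int, List[int]]):
        """Run one scheduler step on the worker(s) and return the driver worker's output."""

    @abstractmethod
    def check_health(self) -> None:
        """Raise if the underlying worker(s) are unhealthy (e.g. a dead Ray actor)."""

The async engine additionally calls an execute_model_async counterpart with the same arguments, as the step_async hunk further down shows.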

docs/source/dev/engine/llm_engine.rst

Lines changed: 1 addition & 1 deletion
@@ -2,5 +2,5 @@ LLMEngine
 =================================
 
 .. autoclass:: vllm.engine.llm_engine.LLMEngine
-    :members: add_request, abort_request, step, _init_cache
+    :members: add_request, abort_request, step
     :show-inheritance:

format.sh

Lines changed: 6 additions & 2 deletions
@@ -95,13 +95,17 @@ echo 'vLLM yapf: Done'
 # echo 'vLLM mypy:'
 # mypy
 
+CODESPELL_EXCLUDES=(
+    '--skip' '*docs/source/_build/**'
+)
+
 # check spelling of specified files
 spell_check() {
     codespell "$@"
 }
 
 spell_check_all(){
-    codespell --toml pyproject.toml
+    codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}"
 }
 
 # Spelling check of files that differ from main branch.
@@ -116,7 +120,7 @@ spell_check_changed() {
 
     if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
         git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
-            codespell
+            codespell "${CODESPELL_EXCLUDES[@]}"
     fi
 }

tests/lora/conftest.py

Lines changed: 2 additions & 1 deletion
@@ -152,4 +152,5 @@ def get_model_patched(model_config, device_config, **kwargs):
 @pytest.fixture
 def llama_2_7b_model_extra_embeddings(
         llama_2_7b_engine_extra_embeddings) -> nn.Module:
-    yield llama_2_7b_engine_extra_embeddings.driver_worker.model_runner.model
+    yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
+           model_runner.model)

vllm/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.llm_engine import LLMEngine
-from vllm.engine.ray_utils import initialize_cluster
+from vllm.engine.ray_utils import initialize_ray_cluster
 from vllm.entrypoints.llm import LLM
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -19,5 +19,5 @@
     "EngineArgs",
     "AsyncLLMEngine",
     "AsyncEngineArgs",
-    "initialize_cluster",
+    "initialize_ray_cluster",
 ]

vllm/config.py

Lines changed: 6 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Optional, Union, ClassVar
+from typing import TYPE_CHECKING, Optional, Union, ClassVar
 from dataclasses import dataclass
 import os
 from packaging.version import Version
@@ -10,6 +10,9 @@
 from vllm.transformers_utils.config import get_config
 from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version
 
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
 logger = init_logger(__name__)
 
 _GB = 1 << 30
@@ -397,6 +400,7 @@ def __init__(
         max_parallel_loading_workers: Optional[int] = None,
         disable_custom_all_reduce: bool = False,
         ray_workers_use_nsight: bool = False,
+        placement_group: Optional["PlacementGroup"] = None,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         if is_neuron():
@@ -412,6 +416,7 @@
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
         self.ray_workers_use_nsight = ray_workers_use_nsight
+        self.placement_group = placement_group
 
         self.world_size = pipeline_parallel_size * self.tensor_parallel_size
         # Ray worker is not supported for Neuron backend.
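The TYPE_CHECKING guard keeps ray a typing-only import here, and the new placement_group field lets a caller hand an already-created Ray placement group to ParallelConfig instead of having the engine build one. A usage sketch follows; the Ray calls are standard Ray API, while the ParallelConfig line is abbreviated and its remaining arguments are elided.

import ray
from ray.util.placement_group import placement_group

ray.init()
# One GPU bundle per worker; the strategy is up to the caller.
pg = placement_group([{"GPU": 1}] * 4, strategy="PACK")
ray.get(pg.ready())  # block until the bundles are actually reserved

# parallel_config = ParallelConfig(..., placement_group=pg)  # other arguments as before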

vllm/engine/async_llm_engine.py

Lines changed: 38 additions & 68 deletions
@@ -2,16 +2,16 @@
 import os
 import time
 from functools import partial
-from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator, Callable)
+from typing import (Callable, Dict, Iterable, List, Optional, Set, Tuple, Type,
+                    Union, AsyncIterator)
 
 from transformers import PreTrainedTokenizer
 
 from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
-from vllm.engine.ray_utils import initialize_cluster, ray
+from vllm.engine.ray_utils import initialize_ray_cluster, ray
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -208,17 +208,10 @@ async def step_async(self) -> List[RequestOutput]:
 
         if not scheduler_outputs.is_empty():
             # Execute the model.
-            all_outputs = await self._run_workers_async(
-                "execute_model",
-                driver_kwargs={
-                    "seq_group_metadata_list": seq_group_metadata_list,
-                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
-                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
-                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-                })
-
-            # Only the driver worker returns the sampling results.
-            output = all_outputs[0]
+            output = await self.model_executor.execute_model_async(
+                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
+                scheduler_outputs.blocks_to_swap_out,
+                scheduler_outputs.blocks_to_copy)
         else:
             output = []

@@ -268,37 +261,8 @@ async def add_request_async(
             lora_request=lora_request,
         )
 
-    async def _run_workers_async(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[List[Any]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        coros = []
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        # Run the driver worker asynchronously.
-        driver_executor = getattr(self.driver_worker, method)
-        coros.append(asyncio.get_event_loop().run_in_executor(
-            None, partial(driver_executor, *driver_args, **driver_kwargs)))
-
-        # Run the ray workers asynchronously.
-        for worker in self.workers:
-            coros.append(worker.execute_method.remote(method, *args, **kwargs))
-
-        all_outputs = await asyncio.gather(*coros)
-        return all_outputs
-
-    async def check_health_async(self):
-        """Raises an error if engine is unhealthy."""
-        self._check_if_any_actor_is_dead()
+    async def check_health_async(self) -> None:
+        self.model_executor.check_health()
 
 
 class AsyncLLMEngine:
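The fan-out removed above does not disappear; it moves behind the executor classes. As a rough sketch only (an assumption about shape, not the contents of vllm/executor/gpu_executor.py, which this excerpt does not include), a single-GPU async executor can wrap the blocking driver-worker call the same way the deleted helper did:

import asyncio
from functools import partial

class GPUExecutorAsyncSketch:
    """Hypothetical stand-in for the single-GPU async executor."""

    def __init__(self, driver_worker):
        self.driver_worker = driver_worker

    async def execute_model_async(self, seq_group_metadata_list, blocks_to_swap_in,
                                  blocks_to_swap_out, blocks_to_copy):
        # Run the blocking worker call on the default thread pool, mirroring the
        # run_in_executor + partial pattern of the removed _run_workers_async.
        return await asyncio.get_event_loop().run_in_executor(
            None,
            partial(self.driver_worker.execute_model,
                    seq_group_metadata_list=seq_group_metadata_list,
                    blocks_to_swap_in=blocks_to_swap_in,
                    blocks_to_swap_out=blocks_to_swap_out,
                    blocks_to_copy=blocks_to_copy))

    def check_health(self) -> None:
        # Single-process case: nothing to probe beyond the current process.
        return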
@@ -353,6 +317,34 @@ def __init__(self,
         self._request_tracker: Optional[RequestTracker] = None
         self._errored_with: Optional[BaseException] = None
 
+    @classmethod
+    def from_engine_args(cls,
+                         engine_args: AsyncEngineArgs,
+                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
+        """Creates an async LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
+        if parallel_config.worker_use_ray or engine_args.engine_use_ray:
+            initialize_ray_cluster(parallel_config)
+            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
+            executor_class = RayGPUExecutorAsync
+        else:
+            assert parallel_config.world_size == 1, (
+                "Ray is required if parallel_config.world_size > 1.")
+            from vllm.executor.gpu_executor import GPUExecutorAsync
+            executor_class = GPUExecutorAsync
+        # Create the async LLM engine.
+        engine = cls(parallel_config.worker_use_ray,
+                     engine_args.engine_use_ray,
+                     *engine_configs,
+                     executor_class,
+                     log_requests=not engine_args.disable_log_requests,
+                     log_stats=not engine_args.disable_log_stats,
+                     max_log_len=engine_args.max_log_len,
+                     start_engine_loop=start_engine_loop)
+        return engine
+
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
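From the caller's side only the plumbing changes: from_engine_args now picks RayGPUExecutorAsync when Ray is in play and GPUExecutorAsync otherwise, and placement-group handling stays internal. A usage sketch (the model name is a placeholder):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# tensor_parallel_size=1 keeps world_size == 1, so the plain GPU executor is
# selected; any larger value requires Ray and selects the Ray-based executor.
engine_args = AsyncEngineArgs(model="facebook/opt-125m", tensor_parallel_size=1)
engine = AsyncLLMEngine.from_engine_args(engine_args)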
@@ -670,35 +662,13 @@ async def get_model_config(self) -> ModelConfig:
         else:
             return self.engine.get_model_config()
 
-    @classmethod
-    def from_engine_args(cls,
-                         engine_args: AsyncEngineArgs,
-                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
-        """Creates an async LLM engine from the engine arguments."""
-        # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
-        # Initialize the cluster.
-        placement_group = initialize_cluster(parallel_config,
-                                             engine_args.engine_use_ray)
-        # Create the async LLM engine.
-        engine = cls(parallel_config.worker_use_ray,
-                     engine_args.engine_use_ray,
-                     *engine_configs,
-                     placement_group,
-                     log_requests=not engine_args.disable_log_requests,
-                     log_stats=not engine_args.disable_log_stats,
-                     max_log_len=engine_args.max_log_len,
-                     start_engine_loop=start_engine_loop)
-        return engine
-
     async def do_log_stats(self) -> None:
         if self.engine_use_ray:
             await self.engine.do_log_stats.remote()
         else:
             self.engine.do_log_stats()
 
-    async def check_health(self):
+    async def check_health(self) -> None:
         """Raises an error if engine is unhealthy."""
         t = time.perf_counter()
         logger.debug("Starting health check...")
