Skip to content

Commit eb761e7

Browse files
youkaichaoIsotr0py
authored andcommitted
[core] platform agnostic executor via collective_rpc (vllm-project#11256)
Signed-off-by: youkaichao <[email protected]> Signed-off-by: Isotr0py <[email protected]>
1 parent b2d992f commit eb761e7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+852
-2642
lines changed

tests/engine/test_custom_executor.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,34 @@
11
import asyncio
22
import os
3+
from typing import Any, Dict, List, Optional, Tuple
34

45
import pytest
56

67
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
78
from vllm.engine.async_llm_engine import AsyncLLMEngine
89
from vllm.engine.llm_engine import LLMEngine
9-
from vllm.executor.gpu_executor import GPUExecutor, GPUExecutorAsync
10+
from vllm.executor.uniproc_executor import UniProcExecutor
1011
from vllm.sampling_params import SamplingParams
1112

1213

1314
class Mock:
1415
...
1516

1617

17-
class CustomGPUExecutor(GPUExecutor):
18+
class CustomUniExecutor(UniProcExecutor):
1819

19-
def execute_model(self, *args, **kwargs):
20+
def collective_rpc(self,
21+
method: str,
22+
timeout: Optional[float] = None,
23+
args: Tuple = (),
24+
kwargs: Optional[Dict] = None) -> List[Any]:
2025
# Drop marker to show that this was ran
2126
with open(".marker", "w"):
2227
...
23-
return super().execute_model(*args, **kwargs)
28+
return super().collective_rpc(method, timeout, args, kwargs)
2429

2530

26-
class CustomGPUExecutorAsync(GPUExecutorAsync):
27-
28-
async def execute_model_async(self, *args, **kwargs):
29-
with open(".marker", "w"):
30-
...
31-
return await super().execute_model_async(*args, **kwargs)
31+
CustomUniExecutorAsync = CustomUniExecutor
3232

3333

3434
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@@ -41,10 +41,6 @@ def test_custom_executor_type_checking(model):
4141
engine_args = AsyncEngineArgs(model=model,
4242
distributed_executor_backend=Mock)
4343
AsyncLLMEngine.from_engine_args(engine_args)
44-
with pytest.raises(TypeError):
45-
engine_args = AsyncEngineArgs(
46-
model=model, distributed_executor_backend=CustomGPUExecutor)
47-
AsyncLLMEngine.from_engine_args(engine_args)
4844

4945

5046
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@@ -55,7 +51,7 @@ def test_custom_executor(model, tmp_path):
5551
assert not os.path.exists(".marker")
5652

5753
engine_args = EngineArgs(
58-
model=model, distributed_executor_backend=CustomGPUExecutor)
54+
model=model, distributed_executor_backend=CustomUniExecutor)
5955
engine = LLMEngine.from_engine_args(engine_args)
6056
sampling_params = SamplingParams(max_tokens=1)
6157

@@ -75,7 +71,7 @@ def test_custom_executor_async(model, tmp_path):
7571
assert not os.path.exists(".marker")
7672

7773
engine_args = AsyncEngineArgs(
78-
model=model, distributed_executor_backend=CustomGPUExecutorAsync)
74+
model=model, distributed_executor_backend=CustomUniExecutorAsync)
7975
engine = AsyncLLMEngine.from_engine_args(engine_args)
8076
sampling_params = SamplingParams(max_tokens=1)
8177

tests/engine/test_multiproc_workers.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,15 @@
66

77
import pytest
88

9+
from vllm.config import VllmConfig
910
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
1011
ResultHandler, WorkerMonitor)
12+
from vllm.worker.worker_base import WorkerWrapperBase
1113

1214

13-
class DummyWorker:
15+
class DummyWorkerWrapper(WorkerWrapperBase):
1416
"""Dummy version of vllm.worker.worker.Worker"""
1517

16-
def __init__(self, rank: int):
17-
self.rank = rank
18-
1918
def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
2019
sleep(0.05)
2120

@@ -28,9 +27,10 @@ def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
2827

2928
def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
3029
result_handler = ResultHandler()
30+
vllm_config = VllmConfig()
3131
workers = [
32-
ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
33-
for rank in range(8)
32+
ProcessWorkerWrapper(result_handler, DummyWorkerWrapper, vllm_config,
33+
rank) for rank in range(8)
3434
]
3535

3636
worker_monitor = WorkerMonitor(workers, result_handler)

tests/test_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import socket
44
from typing import AsyncIterator, Tuple
5+
from unittest.mock import patch
56

67
import pytest
78
import torch
@@ -390,7 +391,10 @@ def test_bind_kv_cache_encoder_decoder():
390391

391392

392393
def test_bind_kv_cache_pp():
393-
cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
394+
with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
395+
# this test runs with 1 GPU, but we simulate 2 GPUs
396+
cfg = VllmConfig(
397+
parallel_config=ParallelConfig(pipeline_parallel_size=2))
394398
with set_current_vllm_config(cfg):
395399
from vllm.attention import Attention
396400

vllm/config.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,8 +1294,11 @@ def __post_init__(self) -> None:
12941294
from vllm.executor import ray_utils
12951295
backend = "mp"
12961296
ray_found = ray_utils.ray_is_available()
1297-
if (current_platform.is_cuda()
1298-
and cuda_device_count_stateless() < self.world_size):
1297+
if current_platform.is_neuron():
1298+
# neuron uses single process to control multiple devices
1299+
backend = "uni"
1300+
elif (current_platform.is_cuda()
1301+
and cuda_device_count_stateless() < self.world_size):
12991302
if not ray_found:
13001303
raise ValueError("Unable to load Ray which is "
13011304
"required for multi-node inference, "
@@ -1328,13 +1331,14 @@ def _verify_args(self) -> None:
13281331
from vllm.executor.executor_base import ExecutorBase
13291332
from vllm.platforms import current_platform
13301333
if self.distributed_executor_backend not in (
1331-
"ray", "mp", None) and not (isinstance(
1334+
"ray", "mp", "uni", None) and not (isinstance(
13321335
self.distributed_executor_backend, type) and issubclass(
13331336
self.distributed_executor_backend, ExecutorBase)):
13341337
raise ValueError(
13351338
"Unrecognized distributed executor backend "
13361339
f"{self.distributed_executor_backend}. Supported "
1337-
"values are 'ray', 'mp' or custom ExecutorBase subclass.")
1340+
"values are 'ray', 'mp' 'uni', or custom ExecutorBase"
1341+
" subclass.")
13381342
if self.use_ray:
13391343
from vllm.executor import ray_utils
13401344
ray_utils.assert_ray_available()

vllm/distributed/parallel_state.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -862,12 +862,14 @@ def init_model_parallel_group(
862862
) -> GroupCoordinator:
863863
if use_custom_allreduce is None:
864864
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
865+
from vllm.platforms import current_platform
865866
return GroupCoordinator(
866867
group_ranks=group_ranks,
867868
local_rank=local_rank,
868869
torch_distributed_backend=backend,
869-
use_pynccl=True,
870-
use_custom_allreduce=use_custom_allreduce,
870+
use_pynccl=current_platform.is_cuda_alike(),
871+
use_custom_allreduce=current_platform.is_cuda_alike()
872+
and use_custom_allreduce,
871873
use_tpu_communicator=True,
872874
use_hpu_communicator=True,
873875
use_xpu_communicator=True,

vllm/engine/async_llm_engine.py

Lines changed: 6 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
1919
from vllm.engine.metrics_types import StatLoggerBase
2020
from vllm.engine.protocol import EngineClient
21-
from vllm.executor.executor_base import ExecutorAsyncBase
22-
from vllm.executor.gpu_executor import GPUExecutorAsync
23-
from vllm.executor.ray_utils import initialize_ray_cluster
21+
from vllm.executor.executor_base import ExecutorBase
2422
from vllm.inputs import PromptType
2523
from vllm.inputs.preprocess import InputPreprocessor
2624
from vllm.logger import init_logger
@@ -620,69 +618,9 @@ def __del__(self):
620618
rt.new_requests_event.set()
621619

622620
@classmethod
623-
def _get_executor_cls(
624-
cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]:
625-
distributed_executor_backend = (
626-
engine_config.parallel_config.distributed_executor_backend)
627-
if isinstance(distributed_executor_backend, type):
628-
if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
629-
raise TypeError(
630-
"distributed_executor_backend must be a subclass of "
631-
f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
632-
executor_class = distributed_executor_backend
633-
elif engine_config.device_config.device_type == "neuron":
634-
from vllm.executor.neuron_executor import NeuronExecutorAsync
635-
executor_class = NeuronExecutorAsync
636-
elif engine_config.device_config.device_type == "tpu":
637-
if distributed_executor_backend == "ray":
638-
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
639-
executor_class = RayTPUExecutorAsync
640-
else:
641-
assert distributed_executor_backend is None
642-
from vllm.executor.tpu_executor import TPUExecutorAsync
643-
executor_class = TPUExecutorAsync
644-
elif engine_config.device_config.device_type == "cpu":
645-
from vllm.executor.cpu_executor import CPUExecutorAsync
646-
executor_class = CPUExecutorAsync
647-
elif engine_config.device_config.device_type == "hpu":
648-
if distributed_executor_backend == "ray":
649-
initialize_ray_cluster(engine_config.parallel_config)
650-
from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
651-
executor_class = RayHPUExecutorAsync
652-
else:
653-
from vllm.executor.hpu_executor import HPUExecutorAsync
654-
executor_class = HPUExecutorAsync
655-
elif engine_config.device_config.device_type == "openvino":
656-
assert distributed_executor_backend is None, (
657-
"Distributed execution is not supported with "
658-
"the OpenVINO backend.")
659-
from vllm.executor.openvino_executor import OpenVINOExecutorAsync
660-
executor_class = OpenVINOExecutorAsync
661-
elif engine_config.device_config.device_type == "xpu":
662-
if distributed_executor_backend is None:
663-
from vllm.executor.xpu_executor import XPUExecutorAsync
664-
executor_class = XPUExecutorAsync
665-
elif distributed_executor_backend == "ray":
666-
from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
667-
executor_class = RayXPUExecutorAsync
668-
elif distributed_executor_backend == "mp":
669-
from vllm.executor.multiproc_xpu_executor import (
670-
MultiprocessingXPUExecutorAsync)
671-
executor_class = MultiprocessingXPUExecutorAsync
672-
else:
673-
raise RuntimeError(
674-
"Not supported distributed execution model on XPU device.")
675-
elif distributed_executor_backend == "ray":
676-
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
677-
executor_class = RayGPUExecutorAsync
678-
elif distributed_executor_backend == "mp":
679-
from vllm.executor.multiproc_gpu_executor import (
680-
MultiprocessingGPUExecutorAsync)
681-
executor_class = MultiprocessingGPUExecutorAsync
682-
else:
683-
from vllm.executor.gpu_executor import GPUExecutorAsync
684-
executor_class = GPUExecutorAsync
685-
return executor_class
621+
def _get_executor_cls(cls,
622+
engine_config: VllmConfig) -> Type[ExecutorBase]:
623+
return LLMEngine._get_executor_cls(engine_config)
686624

687625
@classmethod
688626
def from_engine_args(
@@ -700,9 +638,6 @@ def from_engine_args(
700638

701639
executor_class = cls._get_executor_cls(engine_config)
702640

703-
if executor_class.uses_ray:
704-
initialize_ray_cluster(engine_config.parallel_config)
705-
706641
# Create the async LLM engine.
707642
engine = cls(
708643
vllm_config=engine_config,
@@ -1242,23 +1177,12 @@ def remove_logger(self, logger_name: str) -> None:
12421177
self.engine.remove_logger(logger_name=logger_name)
12431178

12441179
async def start_profile(self) -> None:
1245-
# using type instead of isinstance to check to avoid capturing
1246-
# inherited classes
1247-
if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721
1248-
self.engine.model_executor.start_profile()
1249-
else:
1250-
self.engine.model_executor._run_workers("start_profile")
1180+
self.engine.start_profile()
12511181

12521182
async def stop_profile(self) -> None:
1253-
# using type instead of isinstance to check to avoid capturing
1254-
# inherited classes
1255-
if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721
1256-
self.engine.model_executor.stop_profile()
1257-
else:
1258-
self.engine.model_executor._run_workers("stop_profile")
1183+
self.engine.stop_profile()
12591184

12601185
async def add_lora(self, lora_request: LoRARequest) -> None:
1261-
"""Load a new LoRA adapter into the engine for future requests."""
12621186
self.engine.add_lora(lora_request)
12631187

12641188

0 commit comments

Comments
 (0)