[core] platform agnostic executor via collective_rpc #11256

Merged — 112 commits, merged Jan 15, 2025

Commits (112)
88b4d76
remove all args
youkaichao Dec 7, 2024
5998534
stash
youkaichao Dec 7, 2024
f657ce0
Merge branch 'main' into remove_allargs
youkaichao Dec 7, 2024
1b5934f
fix
youkaichao Dec 7, 2024
c8bbd64
Merge branch 'main' into remove_allargs
youkaichao Dec 16, 2024
c4f489a
add rank during init
youkaichao Dec 17, 2024
ce4699f
update env var
youkaichao Dec 17, 2024
bae852e
stash
youkaichao Dec 17, 2024
fbe9cb4
stash
youkaichao Dec 17, 2024
959939b
draft version
youkaichao Dec 17, 2024
8fe836e
fix spec decode
youkaichao Dec 17, 2024
6b30aef
fix
youkaichao Dec 17, 2024
d19d0ba
fix
youkaichao Dec 17, 2024
9e66c7f
fix
youkaichao Dec 17, 2024
bea43de
fix
youkaichao Dec 17, 2024
57d668d
fix
youkaichao Dec 17, 2024
d258f27
fix
youkaichao Dec 17, 2024
76749e0
fix
youkaichao Dec 17, 2024
28e4eae
fix
youkaichao Dec 17, 2024
c6f2cd0
fix
youkaichao Dec 17, 2024
49e5e27
rename
youkaichao Dec 17, 2024
8a6a629
use collective_rpc
youkaichao Dec 17, 2024
bb01b55
stash
youkaichao Dec 17, 2024
df43679
gpu executor
youkaichao Dec 17, 2024
02f1785
gpu executor
youkaichao Dec 17, 2024
6772227
fix kwargs
youkaichao Dec 17, 2024
d3becd3
device name
youkaichao Dec 17, 2024
6d59bbb
tests
youkaichao Dec 17, 2024
a458587
fix
youkaichao Dec 17, 2024
a1f99bf
fix
youkaichao Dec 17, 2024
e7c0ac3
fix
youkaichao Dec 17, 2024
1e3e4e0
fix
youkaichao Dec 17, 2024
fb0407a
fix tests
youkaichao Dec 17, 2024
35f7cf8
refine
youkaichao Dec 17, 2024
af82c4e
refine
youkaichao Dec 17, 2024
463e407
refactor xpu
youkaichao Dec 17, 2024
3a7b204
rename
youkaichao Dec 17, 2024
ccedbf5
rename
youkaichao Dec 17, 2024
dc61935
fix
youkaichao Dec 17, 2024
ef11071
fix
youkaichao Dec 17, 2024
0a49754
hpu
youkaichao Dec 17, 2024
c537b6c
hpu
youkaichao Dec 17, 2024
a71d9c2
remove gpu executor
youkaichao Dec 17, 2024
ca2586d
remove neuron executor
youkaichao Dec 17, 2024
511adb6
openvino
youkaichao Dec 17, 2024
4add989
openvino
youkaichao Dec 17, 2024
0b14f7e
tpu
youkaichao Dec 17, 2024
f853f80
cpu
youkaichao Dec 17, 2024
57692ab
cuda
youkaichao Dec 17, 2024
440a987
hpu and xpu
youkaichao Dec 17, 2024
9976703
ray
youkaichao Dec 17, 2024
096b1ba
_get_executor_cls
youkaichao Dec 17, 2024
9c2d166
tod
youkaichao Dec 17, 2024
7feee81
Merge branch 'main' into remove_allargs
youkaichao Dec 17, 2024
31d59b2
ray
youkaichao Dec 17, 2024
368271d
fix
youkaichao Dec 17, 2024
a0d1293
rename
youkaichao Dec 17, 2024
a133bf9
fix
youkaichao Dec 17, 2024
dae5023
fix
youkaichao Dec 17, 2024
d57de3b
revert device_name change
youkaichao Jan 13, 2025
7c7364e
Merge branch 'main' into executor_refactor
youkaichao Jan 13, 2025
ff31dfb
fix linter
youkaichao Jan 13, 2025
3402c32
add RayWorkerMetaData
youkaichao Jan 13, 2025
7275d5a
add comments
youkaichao Jan 13, 2025
59c95b2
add comments
youkaichao Jan 13, 2025
2366f0f
fix linter
youkaichao Jan 13, 2025
53cb755
format
youkaichao Jan 13, 2025
cacb0cb
format
youkaichao Jan 13, 2025
f56306b
remove extra_execute_model_run_workers_kwargs
youkaichao Jan 13, 2025
ebc6f22
add check_health
youkaichao Jan 13, 2025
786dfcf
add resources
youkaichao Jan 13, 2025
0a6f6ec
fix
youkaichao Jan 13, 2025
d26f3f0
fix
youkaichao Jan 13, 2025
aaa57c7
format
youkaichao Jan 13, 2025
1701db6
format
youkaichao Jan 13, 2025
93bc101
Merge branch 'main' into executor_refactor
youkaichao Jan 13, 2025
f01cb8f
Merge branch 'main' into executor_refactor
youkaichao Jan 13, 2025
c1698dd
use device_control_env_var
youkaichao Jan 13, 2025
8f1ea58
call method directly
youkaichao Jan 13, 2025
78a1a2e
simplify code
youkaichao Jan 13, 2025
493d34e
simplify code
youkaichao Jan 13, 2025
e71fc22
fix tests
youkaichao Jan 13, 2025
9e2fc3f
fix ray args
youkaichao Jan 14, 2025
50a43bd
fix xpu
youkaichao Jan 14, 2025
1690df5
fix spec decode
youkaichao Jan 14, 2025
0e02792
use ray
youkaichao Jan 14, 2025
cd7d24e
Merge branch 'main' into executor_refactor
youkaichao Jan 14, 2025
dcd6735
fix hpu
youkaichao Jan 14, 2025
9b31403
fix xpu
youkaichao Jan 14, 2025
058392f
fix xpu
youkaichao Jan 14, 2025
0be6c88
fix neuron
youkaichao Jan 14, 2025
f5cf4e9
fix neuron
youkaichao Jan 14, 2025
f56f2fb
fix neuron?
youkaichao Jan 14, 2025
b76c0af
fix neuron
youkaichao Jan 14, 2025
8394686
fix v1 code
youkaichao Jan 14, 2025
8d2e536
fix ray
youkaichao Jan 14, 2025
8d77f68
fix neuron
youkaichao Jan 14, 2025
7d6673c
fix ray
youkaichao Jan 14, 2025
d47d6ea
unify initialize_ray_cluster
youkaichao Jan 14, 2025
e6b88c4
fix neuron
youkaichao Jan 14, 2025
5db5dc8
unify signature
youkaichao Jan 14, 2025
37d37a6
fix format
youkaichao Jan 14, 2025
9f411d9
lint
youkaichao Jan 14, 2025
ccaad6b
rename to init device
youkaichao Jan 14, 2025
b156cdd
lint
youkaichao Jan 14, 2025
4e9b2fb
fix v1 compatibility
youkaichao Jan 14, 2025
42cfe34
fix neuron
youkaichao Jan 14, 2025
bb41045
fix neuron
youkaichao Jan 14, 2025
d1ab209
fix neuron
youkaichao Jan 14, 2025
5c3c20b
fix neuron
youkaichao Jan 14, 2025
27ec870
fix ray executor for v1
youkaichao Jan 14, 2025
378f608
fix shard state tests
youkaichao Jan 14, 2025
28 changes: 12 additions & 16 deletions tests/engine/test_custom_executor.py
@@ -1,34 +1,34 @@
import asyncio
import os
from typing import Any, Dict, List, Optional, Tuple

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.gpu_executor import GPUExecutor, GPUExecutorAsync
from vllm.executor.uniproc_executor import UniProcExecutor
from vllm.sampling_params import SamplingParams


class Mock:
...


class CustomGPUExecutor(GPUExecutor):
class CustomUniExecutor(UniProcExecutor):

def execute_model(self, *args, **kwargs):
def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
# Drop marker to show that this was ran
with open(".marker", "w"):
...
return super().execute_model(*args, **kwargs)
return super().collective_rpc(method, timeout, args, kwargs)


class CustomGPUExecutorAsync(GPUExecutorAsync):

async def execute_model_async(self, *args, **kwargs):
with open(".marker", "w"):
...
return await super().execute_model_async(*args, **kwargs)
CustomUniExecutorAsync = CustomUniExecutor


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@@ -41,10 +41,6 @@ def test_custom_executor_type_checking(model):
engine_args = AsyncEngineArgs(model=model,
distributed_executor_backend=Mock)
AsyncLLMEngine.from_engine_args(engine_args)
with pytest.raises(TypeError):
engine_args = AsyncEngineArgs(
model=model, distributed_executor_backend=CustomGPUExecutor)
AsyncLLMEngine.from_engine_args(engine_args)


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@@ -55,7 +51,7 @@ def test_custom_executor(model, tmp_path):
assert not os.path.exists(".marker")

engine_args = EngineArgs(
model=model, distributed_executor_backend=CustomGPUExecutor)
model=model, distributed_executor_backend=CustomUniExecutor)
engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)

@@ -75,7 +71,7 @@ def test_custom_executor_async(model, tmp_path):
assert not os.path.exists(".marker")

engine_args = AsyncEngineArgs(
model=model, distributed_executor_backend=CustomGPUExecutorAsync)
model=model, distributed_executor_backend=CustomUniExecutorAsync)
engine = AsyncLLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)

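The updated test captures the core idea of this PR: a custom executor now hooks collective_rpc, the single entry point through which the engine reaches its workers, instead of overriding execute_model. Below is a minimal sketch of the same pattern, assuming the import path and signature shown in the diff above; the class name LoggingUniExecutor is illustrative.

from typing import Any, Dict, List, Optional, Tuple

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.uniproc_executor import UniProcExecutor


class LoggingUniExecutor(UniProcExecutor):
    """Illustrative executor: observe every worker call routed through the RPC."""

    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        # Log which worker method is being invoked, then delegate to the
        # stock single-process executor.
        print(f"collective_rpc -> {method}")
        return super().collective_rpc(method, timeout, args, kwargs)


# Plug the custom class in exactly like the test above does.
engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m",
               distributed_executor_backend=LoggingUniExecutor))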
12 changes: 6 additions & 6 deletions tests/engine/test_multiproc_workers.py
@@ -6,16 +6,15 @@

import pytest

from vllm.config import VllmConfig
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.worker.worker_base import WorkerWrapperBase


class DummyWorker:
class DummyWorkerWrapper(WorkerWrapperBase):
"""Dummy version of vllm.worker.worker.Worker"""

def __init__(self, rank: int):
self.rank = rank

def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
sleep(0.05)

@@ -28,9 +27,10 @@ def worker_method(self, worker_input: Any) -> Tuple[int, Any]:

def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
result_handler = ResultHandler()
vllm_config = VllmConfig()
workers = [
ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
for rank in range(8)
ProcessWorkerWrapper(result_handler, DummyWorkerWrapper, vllm_config,
rank) for rank in range(8)
]

worker_monitor = WorkerMonitor(workers, result_handler)
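As the revised test shows, a worker handed to ProcessWorkerWrapper is now a WorkerWrapperBase subclass, and the wrapper receives the VllmConfig and rank directly rather than a pre-bound factory. A hedged sketch of spawning workers under that constructor; EchoWorkerWrapper and its echo method are illustrative, while the argument order follows the diff above.

from typing import Any

from vllm.config import VllmConfig
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                  ResultHandler, WorkerMonitor)
from vllm.worker.worker_base import WorkerWrapperBase


class EchoWorkerWrapper(WorkerWrapperBase):
    """Illustrative worker: hands its input straight back to the parent."""

    def echo(self, value: Any) -> Any:
        return value


result_handler = ResultHandler()
vllm_config = VllmConfig()
workers = [
    ProcessWorkerWrapper(result_handler, EchoWorkerWrapper, vllm_config, rank)
    for rank in range(4)
]
worker_monitor = WorkerMonitor(workers, result_handler)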
6 changes: 5 additions & 1 deletion tests/test_utils.py
@@ -2,6 +2,7 @@
import os
import socket
from typing import AsyncIterator, Tuple
from unittest.mock import patch

import pytest
import torch
@@ -390,7 +391,10 @@ def test_bind_kv_cache_encoder_decoder():


def test_bind_kv_cache_pp():
cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
# this test runs with 1 GPU, but we simulate 2 GPUs
cfg = VllmConfig(
parallel_config=ParallelConfig(pipeline_parallel_size=2))
with set_current_vllm_config(cfg):
from vllm.attention import Attention

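The patch above is a general trick for exercising multi-device configurations on a single-device CI machine. A minimal sketch of the same pattern, assuming only what the diff already imports:

from unittest.mock import patch

from vllm.config import ParallelConfig, VllmConfig

# Pretend two CUDA devices are visible so a pipeline-parallel config can be
# constructed on a machine that has one GPU, or none at all.
with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
    cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))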
12 changes: 8 additions & 4 deletions vllm/config.py
@@ -1294,8 +1294,11 @@ def __post_init__(self) -> None:
from vllm.executor import ray_utils
backend = "mp"
ray_found = ray_utils.ray_is_available()
if (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size):
if current_platform.is_neuron():
# neuron uses single process to control multiple devices
backend = "uni"
elif (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size):
if not ray_found:
raise ValueError("Unable to load Ray which is "
"required for multi-node inference, "
@@ -1328,13 +1331,14 @@ def _verify_args(self) -> None:
from vllm.executor.executor_base import ExecutorBase
from vllm.platforms import current_platform
if self.distributed_executor_backend not in (
"ray", "mp", None) and not (isinstance(
"ray", "mp", "uni", None) and not (isinstance(
self.distributed_executor_backend, type) and issubclass(
self.distributed_executor_backend, ExecutorBase)):
raise ValueError(
"Unrecognized distributed executor backend "
f"{self.distributed_executor_backend}. Supported "
"values are 'ray', 'mp' or custom ExecutorBase subclass.")
"values are 'ray', 'mp' 'uni', or custom ExecutorBase"
" subclass.")
if self.use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
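With this hunk a third built-in string value, "uni", joins "ray" and "mp"; the config auto-selects it on Neuron, where a single process controls all devices. A hedged sketch of requesting it explicitly follows; the model name is illustrative, and on Neuron __post_init__ would pick "uni" on its own.

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

# Force the single-process executor; a custom ExecutorBase subclass would
# also be accepted by _verify_args, per the check shown above.
engine_args = EngineArgs(model="facebook/opt-125m",
                         distributed_executor_backend="uni")
engine = LLMEngine.from_engine_args(engine_args)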
6 changes: 4 additions & 2 deletions vllm/distributed/parallel_state.py
@@ -862,12 +862,14 @@ def init_model_parallel_group(
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
from vllm.platforms import current_platform
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=True,
use_custom_allreduce=use_custom_allreduce,
use_pynccl=current_platform.is_cuda_alike(),
use_custom_allreduce=current_platform.is_cuda_alike()
and use_custom_allreduce,
use_tpu_communicator=True,
use_hpu_communicator=True,
use_xpu_communicator=True,
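The net effect of this hunk is that the NCCL wrapper and the custom all-reduce kernel are only considered on CUDA-like (CUDA/ROCm) platforms; other devices rely on their own communicators. A small sketch of the guard, with requested_custom_allreduce standing in for the caller-supplied flag:

from vllm.platforms import current_platform

# Caller-supplied preference (illustrative).
requested_custom_allreduce = True

# Only CUDA-like devices get pynccl and the custom all-reduce kernel;
# TPU/HPU/XPU keep using their dedicated communicators instead.
cuda_alike = current_platform.is_cuda_alike()
use_pynccl = cuda_alike
use_custom_allreduce = cuda_alike and requested_custom_allreduce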
88 changes: 6 additions & 82 deletions vllm/engine/async_llm_engine.py
@@ -18,9 +18,7 @@
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
@@ -620,69 +618,9 @@ def __del__(self):
rt.new_requests_event.set()

@classmethod
def _get_executor_cls(
cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
if isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
if distributed_executor_backend == "ray":
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
executor_class = RayTPUExecutorAsync
else:
assert distributed_executor_backend is None
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.device_config.device_type == "hpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
executor_class = RayHPUExecutorAsync
else:
from vllm.executor.hpu_executor import HPUExecutorAsync
executor_class = HPUExecutorAsync
elif engine_config.device_config.device_type == "openvino":
assert distributed_executor_backend is None, (
"Distributed execution is not supported with "
"the OpenVINO backend.")
from vllm.executor.openvino_executor import OpenVINOExecutorAsync
executor_class = OpenVINOExecutorAsync
elif engine_config.device_config.device_type == "xpu":
if distributed_executor_backend is None:
from vllm.executor.xpu_executor import XPUExecutorAsync
executor_class = XPUExecutorAsync
elif distributed_executor_backend == "ray":
from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
executor_class = RayXPUExecutorAsync
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_xpu_executor import (
MultiprocessingXPUExecutorAsync)
executor_class = MultiprocessingXPUExecutorAsync
else:
raise RuntimeError(
"Not supported distributed execution model on XPU device.")
elif distributed_executor_backend == "ray":
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutorAsync)
executor_class = MultiprocessingGPUExecutorAsync
else:
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
return executor_class
def _get_executor_cls(cls,
engine_config: VllmConfig) -> Type[ExecutorBase]:
return LLMEngine._get_executor_cls(engine_config)

@classmethod
def from_engine_args(
@@ -700,9 +638,6 @@ def from_engine_args(

executor_class = cls._get_executor_cls(engine_config)

if executor_class.uses_ray:
initialize_ray_cluster(engine_config.parallel_config)

# Create the async LLM engine.
engine = cls(
vllm_config=engine_config,
@@ -1242,23 +1177,12 @@ def remove_logger(self, logger_name: str) -> None:
self.engine.remove_logger(logger_name=logger_name)

async def start_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721
self.engine.model_executor.start_profile()
else:
self.engine.model_executor._run_workers("start_profile")
self.engine.start_profile()

async def stop_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721
self.engine.model_executor.stop_profile()
else:
self.engine.model_executor._run_workers("stop_profile")
self.engine.stop_profile()

async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
self.engine.add_lora(lora_request)


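After this refactor the async engine keeps no per-platform executor table of its own: _get_executor_cls defers to LLMEngine, and the profiling entry points simply forward to the wrapped engine, which, per the rest of the PR, reaches the workers through the executor's collective_rpc. A hedged usage sketch (model name illustrative):

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def profile_a_bit() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    # Both calls now delegate to the wrapped engine's start_profile() /
    # stop_profile(), so they behave the same on every executor backend.
    await engine.start_profile()
    # ... submit and await some generation requests here ...
    await engine.stop_profile()


# asyncio.run(profile_a_bit())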