
[Core] Multiprocessing executor for single-node multi-GPU [1/2] #4345


Closed · wants to merge 2 commits
176 changes: 176 additions & 0 deletions tests/engine/test_multiproc_workers.py
@@ -0,0 +1,176 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from time import sleep
from typing import Any, List, Tuple

import pytest

from vllm.executor.multiproc_worker_utils import (LocalWorkerVllm,
ResultHandler, WorkerMonitor)


class DummyWorker:
"""Dummy version of vllm.worker.worker.Worker"""

def __init__(self, rank: int):
self.rank = rank

def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
sleep(0.05)

if isinstance(worker_input, Exception):
# simulate error case
raise worker_input

        return self.rank, worker_input


def _start_workers() -> Tuple[List[LocalWorkerVllm], WorkerMonitor]:
result_handler = ResultHandler()
workers = [
LocalWorkerVllm(result_handler, partial(DummyWorker, rank=rank))
for rank in range(8)
]

worker_monitor = WorkerMonitor(workers, result_handler)
assert not worker_monitor.is_alive()

result_handler.start()
worker_monitor.start()
assert worker_monitor.is_alive()

return workers, worker_monitor


def test_local_workers() -> None:
"""Test workers with sync task submission"""

workers, worker_monitor = _start_workers()

def execute_workers(worker_input: str) -> None:
worker_outputs = [
worker.execute_method("worker_method", worker_input)
for worker in workers
]

for rank, output in enumerate(worker_outputs):
            assert output.get() == (rank, worker_input)

executor = ThreadPoolExecutor(max_workers=4)

# Test concurrent submission from different threads
futures = [
executor.submit(partial(execute_workers, f"thread {thread_num}"))
for thread_num in range(4)
]

for future in futures:
future.result()

# Test error case
exception = ValueError("fake error")
result = workers[0].execute_method("worker_method", exception)
try:
result.get()
pytest.fail("task should have failed")
except Exception as e:
assert isinstance(e, ValueError)
assert str(e) == "fake error"

# Test cleanup when a worker fails
assert worker_monitor.is_alive()
workers[3].process.kill()

# Other workers should get shut down here
worker_monitor.join(2)

# Ensure everything is stopped
assert not worker_monitor.is_alive()
assert all(not worker.process.is_alive() for worker in workers)

# Further attempts to submit tasks should fail
try:
_result = workers[0].execute_method("worker_method", "test")
pytest.fail("task should fail once workers have been shut down")
except Exception as e:
assert isinstance(e, ChildProcessError)


def test_local_workers_clean_shutdown() -> None:
"""Test clean shutdown"""

workers, worker_monitor = _start_workers()

assert worker_monitor.is_alive()
assert all(worker.process.is_alive() for worker in workers)

# Clean shutdown
worker_monitor.close()

worker_monitor.join(2)

# Ensure everything is stopped
assert not worker_monitor.is_alive()
assert all(not worker.process.is_alive() for worker in workers)

# Further attempts to submit tasks should fail
try:
_result = workers[0].execute_method("worker_method", "test")
pytest.fail("task should fail once workers have been shut down")
except Exception as e:
assert isinstance(e, ChildProcessError)


@pytest.mark.asyncio
async def test_local_workers_async() -> None:
"""Test local workers with async task submission"""

workers, worker_monitor = _start_workers()

async def execute_workers(worker_input: str) -> None:
worker_coros = [
worker.execute_method_async("worker_method", worker_input)
for worker in workers
]

results = await asyncio.gather(*worker_coros)
for rank, result in enumerate(results):
            assert result == (rank, worker_input)

tasks = [
asyncio.create_task(execute_workers(f"task {task_num}"))
for task_num in range(4)
]

for task in tasks:
await task

# Test error case
exception = ValueError("fake error")
try:
_result = await workers[0].execute_method_async(
"worker_method", exception)
pytest.fail("task should have failed")
except Exception as e:
assert isinstance(e, ValueError)
assert str(e) == "fake error"

# Test cleanup when a worker fails
assert worker_monitor.is_alive()
workers[3].process.kill()

# Other workers should get shut down here
worker_monitor.join(2)

# Ensure everything is stopped
assert not worker_monitor.is_alive()
assert all(not worker.process.is_alive() for worker in workers)

# Further attempts to submit tasks should fail
try:
_result = await workers[0].execute_method_async(
"worker_method", "test")
pytest.fail("task should fail once workers have been shut down")
except Exception as e:
assert isinstance(e, ChildProcessError)
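
The tests above exercise a future-style RPC pattern: execute_method enqueues a method call for a child process and returns a handle whose get() blocks until the result (or the raised exception) comes back, while a single ResultHandler thread demultiplexes results from all workers. The actual implementation lives in vllm/executor/multiproc_worker_utils.py, which is not part of this diff; the following is a minimal standalone sketch of the pattern, where MiniFuture, MiniResultHandler, MiniLocalWorker, and Echo are hypothetical simplified stand-ins rather than the real API.

import asyncio
import threading
from concurrent.futures import Future
from multiprocessing import Process, Queue
from typing import Any, Callable, Dict, Tuple


class MiniFuture(Future):
    # The real handles expose get(); alias it onto concurrent.futures.Future.
    get = Future.result


def _worker_loop(factory: Callable[[], Any], tasks: Queue,
                 results: Queue) -> None:
    # Child process: build the worker once, then serve method calls forever.
    worker = factory()
    while True:
        task_id, method, args = tasks.get()
        try:
            results.put((task_id, None, getattr(worker, method)(*args)))
        except Exception as e:
            # Ship failures back to the parent instead of dying.
            results.put((task_id, e, None))


class MiniResultHandler(threading.Thread):
    """One daemon thread routing results from all workers to their futures."""

    def __init__(self) -> None:
        super().__init__(daemon=True)
        self.results: Queue = Queue()
        self._futures: Dict[int, MiniFuture] = {}
        self._lock = threading.Lock()
        self._counter = 0

    def register(self) -> Tuple[int, MiniFuture]:
        with self._lock:
            self._counter += 1
            future = MiniFuture()
            self._futures[self._counter] = future
            return self._counter, future

    def run(self) -> None:
        while True:
            task_id, exc, value = self.results.get()
            future = self._futures.pop(task_id)
            if exc is not None:
                future.set_exception(exc)
            else:
                future.set_result(value)


class MiniLocalWorker:
    """Owns one child process; submitted calls return future-like handles."""

    def __init__(self, handler: MiniResultHandler,
                 factory: Callable[[], Any]) -> None:
        self._handler = handler
        self._tasks: Queue = Queue()
        self.process = Process(target=_worker_loop,
                               args=(factory, self._tasks, handler.results),
                               daemon=True)
        self.process.start()

    def execute_method(self, method: str, *args: Any) -> MiniFuture:
        task_id, future = self._handler.register()
        self._tasks.put((task_id, method, args))
        return future

    async def execute_method_async(self, method: str, *args: Any) -> Any:
        # wrap_future bridges the worker's result into the running event loop.
        return await asyncio.wrap_future(self.execute_method(method, *args))


class Echo:
    def ping(self, x: Any) -> Any:
        return x


if __name__ == "__main__":
    handler = MiniResultHandler()
    handler.start()
    worker = MiniLocalWorker(handler, Echo)
    print(worker.execute_method("ping", "hello").get())  # "hello" round-trips

The WorkerMonitor is the remaining piece the tests check: it watches the child processes and, when one dies or close() is called, shuts down the rest and makes further submissions fail with ChildProcessError.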
5 changes: 2 additions & 3 deletions tests/spec_decode/utils.py
@@ -11,7 +11,7 @@
from vllm.sequence import (Logprob, SamplerOutput, SequenceData,
SequenceGroupMetadata, SequenceGroupOutput,
SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.utils import get_distributed_init_method
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker

@@ -112,8 +112,7 @@ def create_worker(cls: type,
)
engine_config = engine_args.create_engine_config()

distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
distributed_init_method = get_distributed_init_method()

worker = cls(
model_config=engine_config.model_config,
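
Note that this file and tests/worker/test_swap.py below now call get_distributed_init_method with no arguments. The helper's new definition is not included in this diff; presumably it defaults its former ip/port arguments internally. A plausible sketch, assuming it keeps the existing tcp:// URL format and reuses the get_ip/get_open_port helpers it previously relied on:

from typing import Optional

from vllm.utils import get_ip, get_open_port  # pre-existing helpers


def get_distributed_init_method(ip: Optional[str] = None,
                                port: Optional[int] = None) -> str:
    # Hypothetical zero-arg form: fall back to this host and a free port.
    if ip is None:
        ip = get_ip()
    if port is None:
        port = get_open_port()
    return f"tcp://{ip}:{port}"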
5 changes: 2 additions & 3 deletions tests/worker/test_swap.py
@@ -1,7 +1,7 @@
import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.utils import get_distributed_init_method
from vllm.worker.worker import Worker


@@ -15,8 +15,7 @@ def test_swap() -> None:
engine_config.cache_config.num_cpu_blocks = 1000

# Create the worker.
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
distributed_init_method = get_distributed_init_method()
worker = Worker(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
2 changes: 1 addition & 1 deletion vllm/__init__.py
@@ -3,8 +3,8 @@
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
2 changes: 1 addition & 1 deletion vllm/engine/async_llm_engine.py
@@ -10,7 +10,7 @@
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_ray_cluster, ray
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
12 changes: 10 additions & 2 deletions vllm/engine/llm_engine.py
@@ -15,8 +15,8 @@
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
@@ -28,7 +28,7 @@
get_tokenizer_group)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
from vllm.utils import Counter, enable_trace_function_call_for_thread

logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
@@ -133,6 +133,8 @@ def __init__(
self.decoding_config = decoding_config or DecodingConfig()
self.log_stats = log_stats

enable_trace_function_call_for_thread()

if not self.model_config.skip_tokenizer_init:
self.tokenizer: BaseTokenizerGroup
self._init_tokenizer()
@@ -287,6 +289,12 @@ def __reduce__(self):
# the closure used to initialize Ray worker actors
raise RuntimeError("LLMEngine should not be pickled!")

def __del__(self):
# Shutdown model executor when engine is garbage collected
# Use getattr since __init__ can fail before the field is set
if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown()

def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(None)

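
The getattr in __del__ above matters because Python still finalizes an instance whose __init__ raised partway through, in which case model_executor may never have been assigned. A small self-contained illustration of the pattern (the Half class is hypothetical):

class Half:

    def __init__(self) -> None:
        raise RuntimeError("fails before self.model_executor is assigned")
        self.model_executor = object()  # never reached

    def __del__(self) -> None:
        # A bare self.model_executor here would raise AttributeError during GC.
        if executor := getattr(self, "model_executor", None):
            executor.shutdown()


try:
    Half()
except RuntimeError:
    pass  # the half-built instance is still finalized, safely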
6 changes: 2 additions & 4 deletions vllm/executor/cpu_executor.py
@@ -8,8 +8,7 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.utils import get_distributed_init_method, make_async

logger = init_logger(__name__)

@@ -33,8 +32,7 @@ def _init_worker(self):
assert self.parallel_config.world_size == 1, (
"CPUExecutor only supports single CPU socket currently.")

distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
distributed_init_method = get_distributed_init_method()
self.driver_worker = CPUWorker(
model_config=self.model_config,
parallel_config=self.parallel_config,
7 changes: 7 additions & 0 deletions vllm/executor/executor_base.py
@@ -95,6 +95,13 @@ def check_health(self) -> None:
exception."""
raise NotImplementedError

def shutdown(self) -> None:
"""Shutdown the executor."""
return

def __del__(self):
self.shutdown()


class ExecutorAsyncBase(ExecutorBase):

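
The base-class change makes shutdown a safe no-op default and wires it to __del__, so every executor releases its resources on garbage collection without each subclass having to remember the hook. Because an explicit shutdown call and the later __del__ can both fire, an override should be idempotent; a hypothetical sketch of that pattern:

class OwnsWorkers:
    """Hypothetical executor-like class whose shutdown is safe to call twice."""

    def __init__(self) -> None:
        self._closed = False

    def shutdown(self) -> None:
        if self._closed:
            return  # idempotent: explicit call and __del__ may both run
        self._closed = True
        print("worker processes terminated")

    def __del__(self) -> None:
        self.shutdown()


e = OwnsWorkers()
e.shutdown()  # prints once
del e         # __del__ calls shutdown() again; the guard makes it a no-op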