diff --git a/tests/entrypoints/test_api_server_process_manager.py b/tests/entrypoints/test_api_server_process_manager.py new file mode 100644 index 00000000000..0dd1fdd9969 --- /dev/null +++ b/tests/entrypoints/test_api_server_process_manager.py @@ -0,0 +1,268 @@ +# SPDX-License-Identifier: Apache-2.0 + +import multiprocessing +import socket +import threading +import time +from typing import Optional +from unittest.mock import patch + +import pytest + +from vllm.v1.utils import (APIServerProcessManager, + wait_for_completion_or_failure) + +# Global variables to control worker behavior +WORKER_RUNTIME_SECONDS = 0.5 + + +# Mock implementation of run_api_server_worker +def mock_run_api_server_worker(listen_address, sock, args, client_config=None): + """Mock run_api_server_worker that runs for a specific time.""" + print(f"Mock worker started with client_config: {client_config}") + time.sleep(WORKER_RUNTIME_SECONDS) + print("Mock worker completed successfully") + + +@pytest.fixture +def api_server_args(): + """Fixture to provide arguments for APIServerProcessManager.""" + sock = socket.socket() + return { + "target_server_fn": + mock_run_api_server_worker, + "listen_address": + "localhost:8000", + "sock": + sock, + "args": + "test_args", # Simple string to avoid pickling issues + "num_servers": + 3, + "input_addresses": [ + "tcp://127.0.0.1:5001", "tcp://127.0.0.1:5002", + "tcp://127.0.0.1:5003" + ], + "output_addresses": [ + "tcp://127.0.0.1:6001", "tcp://127.0.0.1:6002", + "tcp://127.0.0.1:6003" + ], + "stats_update_address": + "tcp://127.0.0.1:7000", + } + + +@pytest.mark.parametrize("with_stats_update", [True, False]) +def test_api_server_process_manager_init(api_server_args, with_stats_update): + """Test initializing the APIServerProcessManager.""" + # Set the worker runtime to ensure tests complete in reasonable time + global WORKER_RUNTIME_SECONDS + WORKER_RUNTIME_SECONDS = 0.5 + + # Copy the args to avoid mutating the + args = api_server_args.copy() + + if not with_stats_update: + args.pop("stats_update_address") + manager = APIServerProcessManager(**args) + + try: + # Verify the manager was initialized correctly + assert len(manager.processes) == 3 + + # Verify all processes are running + for proc in manager.processes: + assert proc.is_alive() + + print("Waiting for processes to run...") + time.sleep(WORKER_RUNTIME_SECONDS / 2) + + # They should still be alive at this point + for proc in manager.processes: + assert proc.is_alive() + + finally: + # Always clean up the processes + print("Cleaning up processes...") + manager.close() + + # Give processes time to terminate + time.sleep(0.2) + + # Verify all processes were terminated + for proc in manager.processes: + assert not proc.is_alive() + + +@patch("vllm.entrypoints.cli.serve.run_api_server_worker", + mock_run_api_server_worker) +def test_wait_for_completion_or_failure(api_server_args): + """Test that wait_for_completion_or_failure works with failures.""" + global WORKER_RUNTIME_SECONDS + WORKER_RUNTIME_SECONDS = 1.0 + + # Create the manager + manager = APIServerProcessManager(**api_server_args) + + try: + assert len(manager.processes) == 3 + + # Create a result capture for the thread + result: dict[str, Optional[Exception]] = {"exception": None} + + def run_with_exception_capture(): + try: + wait_for_completion_or_failure(api_server_manager=manager) + except Exception as e: + result["exception"] = e + + # Start a thread to run wait_for_completion_or_failure + wait_thread = threading.Thread(target=run_with_exception_capture, + 
daemon=True) + wait_thread.start() + + # Let all processes run for a short time + time.sleep(0.2) + + # All processes should still be running + assert all(proc.is_alive() for proc in manager.processes) + + # Now simulate a process failure + print("Simulating process failure...") + manager.processes[0].terminate() + + # Wait for the wait_for_completion_or_failure + # to detect and handle the failure + # This should trigger it to terminate all other processes + wait_thread.join(timeout=1.0) + + # The wait thread should have exited + assert not wait_thread.is_alive() + + # Verify that an exception was raised with appropriate error message + assert result["exception"] is not None + assert "died with exit code" in str(result["exception"]) + + # All processes should now be terminated + for i, proc in enumerate(manager.processes): + assert not proc.is_alive(), f"Process {i} should not be alive" + + finally: + manager.close() + time.sleep(0.2) + + +@pytest.mark.timeout(30) +def test_normal_completion(api_server_args): + """Test that wait_for_completion_or_failure works in normal completion.""" + global WORKER_RUNTIME_SECONDS + WORKER_RUNTIME_SECONDS = 0.1 + + # Create the manager + manager = APIServerProcessManager(**api_server_args) + + try: + # Give processes time to terminate + # wait for processes to complete + remaining_processes = manager.processes.copy() + while remaining_processes: + for proc in remaining_processes: + if not proc.is_alive(): + remaining_processes.remove(proc) + time.sleep(0.1) + + # Verify all processes have terminated + for i, proc in enumerate(manager.processes): + assert not proc.is_alive( + ), f"Process {i} still alive after terminate()" + + # Now call wait_for_completion_or_failure + # since all processes have already + # terminated, it should return immediately + # with no error + wait_for_completion_or_failure(api_server_manager=manager) + + finally: + # Clean up just in case + manager.close() + time.sleep(0.2) + + +@pytest.mark.timeout(30) +def test_external_process_monitoring(api_server_args): + """Test that wait_for_completion_or_failure handles additional processes.""" + global WORKER_RUNTIME_SECONDS + WORKER_RUNTIME_SECONDS = 100 + + # Create and start the external process + # (simulates local_engine_manager or coordinator) + spawn_context = multiprocessing.get_context("spawn") + external_proc = spawn_context.Process(target=mock_run_api_server_worker, + name="MockExternalProcess") + external_proc.start() + + # Create the class to simulate a coordinator + class MockCoordinator: + + def __init__(self, proc): + self.proc = proc + + def close(self): + if self.proc.is_alive(): + self.proc.terminate() + self.proc.join(timeout=0.5) + + # Create a mock coordinator with the external process + mock_coordinator = MockCoordinator(external_proc) + + # Create the API server manager + manager = APIServerProcessManager(**api_server_args) + + try: + # Verify manager initialization + assert len(manager.processes) == 3 + + # Create a result capture for the thread + result: dict[str, Optional[Exception]] = {"exception": None} + + def run_with_exception_capture(): + try: + wait_for_completion_or_failure(api_server_manager=manager, + coordinator=mock_coordinator) + except Exception as e: + result["exception"] = e + + # Start a thread to run wait_for_completion_or_failure + wait_thread = threading.Thread(target=run_with_exception_capture, + daemon=True) + wait_thread.start() + + # Terminate the external process to trigger a failure + time.sleep(0.2) + external_proc.terminate() + + # 
Wait for the thread to detect the failure + wait_thread.join(timeout=1.0) + + # The wait thread should have completed + assert not wait_thread.is_alive( + ), "wait_for_completion_or_failure thread still running" + + # Verify that an exception was raised with appropriate error message + assert result["exception"] is not None, "No exception was raised" + error_message = str(result["exception"]) + assert "died with exit code" in error_message, \ + f"Unexpected error message: {error_message}" + assert "MockExternalProcess" in error_message, \ + f"Error doesn't mention external process: {error_message}" + + # Verify that all API server processes were terminated as a result + for i, proc in enumerate(manager.processes): + assert not proc.is_alive( + ), f"API server process {i} was not terminated" + + finally: + # Clean up + manager.close() + mock_coordinator.close() + time.sleep(0.2) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 43a27da2dbe..d3d62cf0923 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -45,7 +45,6 @@ def make_request(request_id, multi_modal_placeholders=mm_positions, sampling_params=SamplingParams(max_tokens=17), eos_token_id=100, - arrival_time=0, lora_request=None, cache_salt=cache_salt, ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 3da27786b1f..ba3c0b3cf31 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -38,7 +38,6 @@ def make_request(request_id, sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs), eos_token_id=100, - arrival_time=0, lora_request=None, cache_salt=cache_salt, ) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index f40d477a003..80a8978cdca 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -138,7 +138,6 @@ def create_requests(num_requests: int, multi_modal_placeholders=mm_position, multi_modal_hashes=None, eos_token_id=EOS_TOKEN_ID, - arrival_time=0, ) requests.append(request) return requests @@ -732,7 +731,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): prompt_logprobs_dict={}, ) engine_core_outputs = scheduler.update_from_output(output, - model_runner_output) + model_runner_output)[0] for i in range(len(requests)): running_req = scheduler.running[i] diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py index ce4c4d198db..49dc1cf8e06 100644 --- a/tests/v1/test_async_llm_dp.py +++ b/tests/v1/test_async_llm_dp.py @@ -19,8 +19,9 @@ model="ibm-research/PowerMoE-3b", enforce_eager=True, disable_log_requests=True, - tensor_parallel_size=int(os.getenv("TP_SIZE", 1)), + tensor_parallel_size=int(os.getenv("TP_SIZE", 2)), data_parallel_size=int(os.getenv("DP_SIZE", 2)), + data_parallel_address="172.31.15.128", ) if not current_platform.supports_v1(engine_args.create_model_config()): @@ -59,14 +60,22 @@ async def generate(engine: AsyncLLM, @pytest.mark.parametrize( - "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) + "output_kind", + [ + RequestOutputKind.DELTA, + # RequestOutputKind.FINAL_ONLY, + ], +) +@pytest.mark.parametrize("data_parallel_backend", ["ray"]) @pytest.mark.asyncio -async def test_load(output_kind: RequestOutputKind): +async def test_load(output_kind: RequestOutputKind, + data_parallel_backend: str): with ExitStack() as after: prompt = "This is a test of data parallel" + 
engine_args.data_parallel_backend = data_parallel_backend engine = AsyncLLM.from_engine_args(engine_args) after.callback(engine.shutdown) @@ -82,7 +91,6 @@ async def test_load(output_kind: RequestOutputKind): asyncio.create_task( generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS))) - # Confirm that we got all the EXPECTED tokens from the requests. done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) diff --git a/vllm/config.py b/vllm/config.py index a185a75c6bf..183a8ac3912 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1693,6 +1693,8 @@ class ParallelConfig: """Port for data parallel messaging.""" data_parallel_master_port: int = 29500 """Port of the data parallel master.""" + data_parallel_backend: str = "mp" + """Backend to use for data parallel, either "mp" or "ray".""" enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" max_parallel_loading_workers: Optional[int] = None @@ -1856,6 +1858,10 @@ def __post_init__(self) -> None: "please install Ray with `pip install " "ray`.") from ray_utils.ray_import_err backend = "ray" + elif self.data_parallel_backend == "ray": + logger.info("Using ray distributed inference because " + "data_parallel_backend is ray") + backend = "ray" elif ray_found: if self.placement_group: backend = "ray" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f0c6b15b79d..3c76bac396b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -290,6 +290,7 @@ class EngineArgs: data_parallel_size_local: Optional[int] = None data_parallel_address: Optional[str] = None data_parallel_rpc_port: Optional[int] = None + data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers @@ -618,6 +619,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, help='Port for data parallel RPC ' 'communication.') + parallel_group.add_argument('--data-parallel-backend', + '-dpb', + type=str, + help='Backend for data parallel, either ' + '"mp" or "ray".') parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) @@ -1058,6 +1064,8 @@ def create_engine_config( self.data_parallel_rpc_port is not None) else ParallelConfig.data_parallel_rpc_port + data_parallel_backend = self.data_parallel_backend + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, @@ -1065,6 +1073,7 @@ def create_engine_config( data_parallel_size_local=data_parallel_size_local, data_parallel_master_ip=data_parallel_address, data_parallel_rpc_port=data_parallel_rpc_port, + data_parallel_backend=data_parallel_backend, enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 04be7c03399..f7362499fe1 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -1,25 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import os import signal +import socket +import sys +from multiprocessing import connection +from typing import Any, Union import uvloop +import zmq import vllm.envs as envs from vllm import AsyncEngineArgs from 
vllm.entrypoints.cli.types import CLISubcommand -from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.api_server import (run_server, run_server_worker, + setup_server) from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) +from vllm.executor.multiproc_worker_utils import _add_prefix from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser, get_tcp_uri +from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx +from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.core_client import CoreEngineProcManager from vllm.v1.executor.abstract import Executor +from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus +from vllm.v1.utils import (APIServerProcessManager, CoreEngine, + CoreEngineActorManager, get_engine_client_zmq_addr, + wait_for_completion_or_failure, + wait_for_engine_startup, wait_for_ray_engine_actors) logger = init_logger(__name__) +# Type for process sentinel objects +SentinelType = Union[connection.Connection, socket.socket, int] + class ServeSubcommand(CLISubcommand): """The `serve` subcommand for the vLLM CLI. """ @@ -34,9 +51,12 @@ def cmd(args: argparse.Namespace) -> None: if hasattr(args, 'model_tag') and args.model_tag is not None: args.model = args.model_tag - if args.headless: + if args.headless or args.api_server_count < 1: run_headless(args) + elif args.api_server_count > 1: + run_multi_api_server(args) else: + # Single API server (this process). uvloop.run(run_server(args)) def validate(self, args: argparse.Namespace) -> None: @@ -67,6 +87,11 @@ def subparser_init( type=int, default=0, help='Starting data parallel rank for secondary nodes.') + serve_parser.add_argument('--api-server-count', + '-asc', + type=int, + default=1, + help='How many API server processes to run.') serve_parser.add_argument( "--config", type=str, @@ -86,6 +111,9 @@ def cmd_init() -> list[CLISubcommand]: def run_headless(args: argparse.Namespace): + if args.api_server_count > 1: + raise RuntimeError("api_server_count can't be set in headless mode") + # Create the EngineConfig. engine_args = AsyncEngineArgs.from_cli_args(args) usage_context = UsageContext.OPENAI_API_SERVER @@ -98,7 +126,7 @@ def run_headless(args: argparse.Namespace): local_engine_count = parallel_config.data_parallel_size_local host = parallel_config.data_parallel_master_ip port = engine_args.data_parallel_rpc_port # add to config too - input_address = get_tcp_uri(host, port) + handshake_address = get_tcp_uri(host, port) if local_engine_count <= 0: raise RuntimeError("data_parallel_size_local must be > 0 in " @@ -114,7 +142,7 @@ def signal_handler(signum, frame): logger.info( "Launching %d data parallel engine(s) in headless mode, " - "with head node address %s.", local_engine_count, input_address) + "with head node address %s.", local_engine_count, handshake_address) # Create the engines. 
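A note on the new configuration surface: `--data-parallel-backend` (`-dpb`) on the CLI maps onto the `data_parallel_backend` field added to `EngineArgs`/`ParallelConfig` above, and `ParallelConfig.__post_init__` switches the distributed executor to Ray when it is set to `"ray"`. A rough programmatic sketch of the same flow, assuming nothing beyond those fields (the model name is just a placeholder):

```python
from vllm.engine.arg_utils import AsyncEngineArgs

# "mp" is the default; "ray" also selects the Ray distributed executor
# backend when the parallel config is finalized.
engine_args = AsyncEngineArgs(
    model="ibm-research/PowerMoE-3b",  # placeholder model
    data_parallel_size=2,
    data_parallel_backend="ray",
)
vllm_config = engine_args.create_engine_config()
print(vllm_config.parallel_config.data_parallel_backend)  # -> "ray"
```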
engine_manager = CoreEngineProcManager( @@ -124,7 +152,7 @@ def signal_handler(signum, frame): local_start_index=0, vllm_config=vllm_config, on_head_node=False, - input_address=input_address, + handshake_address=handshake_address, executor_class=Executor.get_class(vllm_config), log_stats=not engine_args.disable_log_stats, ) @@ -134,3 +162,160 @@ def signal_handler(signum, frame): finally: logger.info("Shutting down.") engine_manager.close() + + +def run_multi_api_server(args: argparse.Namespace): + + assert not args.headless + num_api_servers = args.api_server_count + #assert num_api_servers > 1 + + if num_api_servers > 1: + setup_multiprocess_prometheus() + + listen_address, sock = setup_server(args) + + engine_args = AsyncEngineArgs.from_cli_args(args) + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = engine_args.create_engine_config(usage_context=usage_context) + model_config = vllm_config.model_config + + if num_api_servers > 1 and model_config.is_multimodal_model and not ( + model_config.disable_mm_preprocessor_cache): + logger.warning("Multi-model preprocessor cache will be disabled for" + " api_server_count > 1") + model_config.disable_mm_preprocessor_cache = True + + parallel_config = vllm_config.parallel_config + + assert parallel_config.data_parallel_rank == 0 + + dp_size = parallel_config.data_parallel_size + local_engine_count = parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + local_only = local_engine_count == dp_size + + # Set up input and output addresses. + input_addresses = [ + get_engine_client_zmq_addr(local_only, host) + for _ in range(num_api_servers) + ] + output_addresses = [ + get_engine_client_zmq_addr(local_only, host) + for _ in range(num_api_servers) + ] + + addresses: dict[str, Any] = { + "input_addresses": input_addresses, + "output_addresses": output_addresses, + } + + # Set up coordinator for dp > 1. + coordinator = None + stats_update_address = None + if dp_size > 1: + # TODO "ready" event for coordinator + coordinator = DPCoordinator(parallel_config) + addresses.update(coordinator.get_engine_socket_addresses()) + stats_update_address = coordinator.get_stats_publish_address() + logger.info("Started DP Coordinator process (PID: %d)", + coordinator.proc.pid) + + if parallel_config.data_parallel_backend == "ray": + logger.info("Starting ray-based data parallel backend") + + engine_actor_manager = CoreEngineActorManager( + local_engine_count=local_engine_count, + start_index=args.data_parallel_start_rank, + local_start_index=0, + vllm_config=vllm_config, + addresses=addresses, + executor_class=Executor.get_class(vllm_config), + log_stats=not engine_args.disable_log_stats, + ) + # Start API servers using the manager + api_server_manager = APIServerProcessManager( + target_server_fn=run_api_server_worker, + listen_address=listen_address, + sock=sock, + args=args, + num_servers=num_api_servers, + input_addresses=input_addresses, + output_addresses=output_addresses, + stats_update_address=stats_update_address) + + wait_for_ray_engine_actors(api_server_manager=api_server_manager, + engine_actor_manager=engine_actor_manager, + coordinator=coordinator) + return + + handshake_address = get_engine_client_zmq_addr( + local_only, host, parallel_config.data_parallel_rpc_port) + + with zmq_socket_ctx(handshake_address, zmq.ROUTER, + bind=True) as handshake_socket: + + # Start local engines. 
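`run_multi_api_server` above is the path taken when `vllm serve` is launched with `--api-server-count` greater than 1; it drives the new `APIServerProcessManager` in both the Ray branch and the multiprocessing branch that follows, and the tests at the top of this diff exercise the same surface. A minimal lifecycle sketch consistent with those call sites (the worker function, addresses, and `args` value are placeholders; the constructor signature is inferred from the call sites rather than from a documented API):

```python
import socket

from vllm.v1.utils import (APIServerProcessManager,
                           wait_for_completion_or_failure)


def worker(listen_address, sock, args, client_config=None):
    ...  # stand-in for run_api_server_worker


sock = socket.socket()
manager = APIServerProcessManager(
    target_server_fn=worker,
    listen_address="http://0.0.0.0:8000",
    sock=sock,
    args=None,  # normally the parsed CLI namespace
    num_servers=2,
    input_addresses=["ipc:///tmp/in0", "ipc:///tmp/in1"],
    output_addresses=["ipc:///tmp/out0", "ipc:///tmp/out1"],
    stats_update_address=None,  # only set when a DP coordinator is running
)
try:
    # Blocks while all servers are healthy; if any monitored process dies it
    # raises and terminates the remaining API server processes.
    wait_for_completion_or_failure(api_server_manager=manager)
finally:
    manager.close()
```

The `coordinator` and `local_engine_manager` keyword arguments used in the tests and in the branch below add further processes to the same watch loop.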
+ if not local_engine_count: + local_engine_manager = None + else: + local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, + vllm_config=vllm_config, + executor_class=Executor.get_class(vllm_config), + log_stats=not engine_args.disable_log_stats, + handshake_address=handshake_address, + on_head_node=True, + local_engine_count=local_engine_count, + start_index=0, + local_start_index=0) + + # Start API servers using the manager + api_server_manager = APIServerProcessManager( + target_server_fn=run_api_server_worker, + listen_address=listen_address, + sock=sock, + args=args, + num_servers=num_api_servers, + input_addresses=input_addresses, + output_addresses=output_addresses, + stats_update_address=stats_update_address) + + # Wait for engine handshakes to complete. + core_engines = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(dp_size) + ] + wait_for_engine_startup( + handshake_socket, + addresses, + core_engines, + parallel_config, + vllm_config.cache_config, + local_engine_manager, + coordinator.proc if coordinator else None, + ) + + # Wait for API servers + wait_for_completion_or_failure( + api_server_manager=api_server_manager, + local_engine_manager=local_engine_manager, + coordinator=coordinator) + + +def run_api_server_worker(listen_address, + sock, + args, + client_config=None, + **uvicorn_kwargs) -> None: + + # Add process-specific prefix to stdout and stderr. + from multiprocessing import current_process + process_name = current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + uvloop.run( + run_server_worker(listen_address, sock, args, client_config, + **uvicorn_kwargs)) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0ab6fcdca1a..dca20af76c7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -18,7 +18,7 @@ from functools import partial from http import HTTPStatus from json import JSONDecodeError -from typing import Annotated, Optional, Union +from typing import Annotated, Any, Optional, Union import prometheus_client import uvloop @@ -26,6 +26,8 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse +from prometheus_client import make_asgi_app +from prometheus_fastapi_instrumentator import Instrumentator from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import State from starlette.routing import Mount @@ -99,6 +101,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path, is_valid_ipv6_address, set_ulimit) +from vllm.v1.metrics.prometheus import get_prometheus_registry from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -144,14 +147,17 @@ async def _force_log(): @asynccontextmanager async def build_async_engine_client( - args: Namespace) -> AsyncIterator[EngineClient]: + args: Namespace, + client_config: Optional[dict[str, Any]] = None, +) -> AsyncIterator[EngineClient]: # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit engine_args = AsyncEngineArgs.from_cli_args(args) async with build_async_engine_client_from_engine_args( - engine_args, args.disable_frontend_multiprocessing) as engine: + engine_args, 
args.disable_frontend_multiprocessing, + client_config) as engine: yield engine @@ -159,6 +165,7 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, + client_config: Optional[dict[str, Any]] = None, ) -> AsyncIterator[EngineClient]: """ Create EngineClient, either: @@ -181,12 +188,16 @@ async def build_async_engine_client_from_engine_args( from vllm.v1.engine.async_llm import AsyncLLM async_llm: Optional[AsyncLLM] = None + client_index = client_config.pop( + "client_index") if client_config else 0 try: async_llm = AsyncLLM.from_vllm_config( vllm_config=vllm_config, usage_context=usage_context, disable_log_requests=engine_args.disable_log_requests, - disable_log_stats=engine_args.disable_log_stats) + disable_log_stats=engine_args.disable_log_stats, + client_addresses=client_config, + client_index=client_index) # Don't keep the dummy data in memory await async_llm.reset_mm_cache() @@ -320,22 +331,9 @@ class PrometheusResponse(Response): def mount_metrics(app: FastAPI): - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app, - multiprocess) - from prometheus_fastapi_instrumentator import Instrumentator - - registry = REGISTRY - - prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) - if prometheus_multiproc_dir_path is not None: - logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", - prometheus_multiproc_dir_path) - registry = CollectorRegistry() - multiprocess.MultiProcessCollector(registry) + """Mount prometheus metrics to a FastAPI app.""" + + registry = get_prometheus_registry() # `response_class=PrometheusResponse` is needed to return an HTTP response # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8" @@ -1285,16 +1283,10 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket: return sock -async def run_server(args, **uvicorn_kwargs) -> None: - logger.info("vLLM API server version %s", VLLM_VERSION) - log_non_default_args(args) - - if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: - ToolParserManager.import_tool_parser(args.tool_parser_plugin) - +def validate_api_server_args(args): valid_tool_parses = ToolParserManager.tool_parsers.keys() if args.enable_auto_tool_choice \ - and args.tool_call_parser not in valid_tool_parses: + and args.tool_call_parser not in valid_tool_parses: raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " f"(chose from {{ {','.join(valid_tool_parses)} }})") @@ -1305,6 +1297,16 @@ async def run_server(args, **uvicorn_kwargs) -> None: f"invalid reasoning parser: {args.reasoning_parser} " f"(chose from {{ {','.join(valid_reasoning_parses)} }})") + +def setup_server(args): + logger.info("vLLM API server version %s", VLLM_VERSION) + log_non_default_args(args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + validate_api_server_args(args) + # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. 
# see https://github.com/vllm-project/vllm/issues/8204 @@ -1321,22 +1323,39 @@ def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler) - async with build_async_engine_client(args) as engine_client: + addr, port = sock_addr + is_ssl = args.ssl_keyfile and args.ssl_certfile + host_part = f"[{addr}]" if is_valid_ipv6_address( + addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" + + return listen_address, sock + + +async def run_server(args, **uvicorn_kwargs) -> None: + listen_address, sock = setup_server(args) + await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) + + +async def run_server_worker(listen_address, + sock, + args, + client_config=None, + **uvicorn_kwargs) -> None: + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + server_index = client_config.get("client_index", 0) if client_config else 0 + + async with build_async_engine_client(args, client_config) as engine_client: app = build_app(args) vllm_config = await engine_client.get_vllm_config() await init_app_state(engine_client, vllm_config, app.state, args) - def _listen_addr(a: str) -> str: - if is_valid_ipv6_address(a): - return '[' + a + ']' - return a or "0.0.0.0" - - is_ssl = args.ssl_keyfile and args.ssl_certfile - logger.info("Starting vLLM API server on http%s://%s:%d", - "s" if is_ssl else "", _listen_addr(sock_addr[0]), - sock_addr[1]) - + logger.info("Starting vLLM API server %d on %s", server_index, + listen_address) shutdown_task = await serve_http( app, sock=sock, diff --git a/vllm/utils.py b/vllm/utils.py index 0cd90c130d3..b269ecbc877 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2421,6 +2421,7 @@ def make_zmq_socket( socket_type: Any, bind: Optional[bool] = None, identity: Optional[bytes] = None, + linger: Optional[int] = None, ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] """Make a ZMQ socket with the proper bind/connect semantics.""" @@ -2440,7 +2441,7 @@ def make_zmq_socket( buf_size = -1 # Use system default buffer size if bind is None: - bind = socket_type != zmq.PUSH + bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB) if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): socket.setsockopt(zmq.RCVHWM, 0) @@ -2453,6 +2454,9 @@ def make_zmq_socket( if identity is not None: socket.setsockopt(zmq.IDENTITY, identity) + if linger is not None: + socket.setsockopt(zmq.LINGER, linger) + # Determine if the path is a TCP socket with an IPv6 address. # Enable IPv6 on the zmq socket if so. scheme, host, _ = split_zmq_path(path) diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index c17f80b6ae7..055ce446051 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -45,7 +45,7 @@ def update_from_output( self, scheduler_output: "SchedulerOutput", model_runner_output: "ModelRunnerOutput", - ) -> "EngineCoreOutputs": + ) -> dict[int, "EngineCoreOutputs"]: """Update the scheduler state based on the model runner output. This method is called after the model runner has processed the scheduled @@ -55,7 +55,8 @@ def update_from_output( for each request. Returns: - A EngineCoreOutputs object containing the outputs for each request. + A dict of client index to EngineCoreOutputs object containing the + outputs for each request originating from that client. 
""" raise NotImplementedError @@ -126,6 +127,11 @@ def reset_prefix_cache(self) -> bool: """ raise NotImplementedError + @abstractmethod + def get_request_counts(self) -> tuple[int, int]: + """Returns (num_running_reqs, num_waiting_reqs).""" + raise NotImplementedError + @abstractmethod def make_stats(self) -> Optional["SchedulerStats"]: """Make a SchedulerStats object for logging. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5ad05485e8f..a89e948f744 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -58,7 +58,8 @@ def __init__( # request ids should be included in the EngineCoreOutputs returned # by update_from_outputs(). This is currently used in the multi-engine # case to track request lifetimes efficiently. - self.include_finished_set = include_finished_set + self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( + defaultdict(set) if include_finished_set else None) # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -678,7 +679,7 @@ def update_from_output( self, scheduler_output: SchedulerOutput, model_runner_output: ModelRunnerOutput, - ) -> EngineCoreOutputs: + ) -> dict[int, EngineCoreOutputs]: sampled_token_ids = model_runner_output.sampled_token_ids spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs @@ -686,7 +687,7 @@ def update_from_output( num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: list[Request] = [] - outputs: list[EngineCoreOutput] = [] + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below @@ -782,7 +783,7 @@ def update_from_output( if new_token_ids or kv_transfer_params: # Add EngineCoreOutput for this Request. - outputs.append( + outputs[request.client_index].append( EngineCoreOutput( request_id=req_id, new_token_ids=new_token_ids, @@ -812,17 +813,35 @@ def update_from_output( self._cached_reqs_data[req_data.req_id].append(req_data) self.running = new_running - engine_core_outputs = EngineCoreOutputs( - outputs=outputs, - scheduler_stats=self.make_stats(spec_decoding_stats), - ) - if self.include_finished_set: - #TODO currently sending duplicates here, improve this - engine_core_outputs.finished_requests = ( - scheduler_output.finished_req_ids | self.finished_req_ids) + + engine_core_outputs = { + client_index: EngineCoreOutputs(outputs=outs) + for client_index, outs in outputs.items() + } + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids is not None: + # Include ids of requests that finished since last outputs + # were sent. + for client_index, finished_set in finished_req_ids.items(): + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs( + finished_requests=finished_set) + finished_req_ids.clear() + + if engine_core_outputs: + # Return stats to only one of the front-ends. 
+ next(iter(engine_core_outputs.values())).scheduler_stats = ( + self.make_stats(spec_decoding_stats)) return engine_core_outputs + def get_request_counts(self) -> tuple[int, int]: + """Returns (num_running_reqs, num_waiting_reqs).""" + return len(self.running), len(self.waiting) + def add_request(self, request: Request) -> None: self.waiting.append(request) self.requests[request.request_id] = request @@ -864,8 +883,11 @@ def _free_request(self, request: Request) -> Optional[dict[str, Any]]: delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) - self._cached_reqs_data.pop(request.request_id, None) - self.finished_req_ids.add(request.request_id) + request_id = request.request_id + self._cached_reqs_data.pop(request_id, None) + self.finished_req_ids.add(request_id) + if self.finished_req_ids_dict is not None: + self.finished_req_ids_dict[request.client_index].add(request_id) if not delay_free_blocks: self._free_blocks(request) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 122a5a72cc3..243c4eafe22 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -44,10 +44,6 @@ class EngineCoreRequest( omit_defaults=True, # type: ignore[call-arg] gc=False): # type: ignore[call-arg] - # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, - # but this object is currently not playing well with msgspec - # due to circular imports and typing we have in data.py - request_id: str prompt_token_ids: list[int] mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] @@ -59,6 +55,8 @@ class EngineCoreRequest( lora_request: Optional[LoRARequest] cache_salt: Optional[str] + client_index: int = 0 + # Used in DP case to indicate which wave of requests this is expected to # belong to, to cover a race condition where the request is sent before # a wave finished notification is received. diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 0d646d8dd57..5fb00df80c7 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -25,7 +25,8 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import Device, cdiv from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient +from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient, + RayDPClient) from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm.v1.engine.output_processor import (OutputProcessor, RequestOutputCollector) @@ -34,6 +35,7 @@ from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, setup_default_loggers) +from vllm.v1.metrics.prometheus import shutdown_prometheus from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -52,6 +54,8 @@ def __init__( log_requests: bool = True, start_engine_loop: bool = True, stat_loggers: Optional[list[StatLoggerFactory]] = None, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, ) -> None: """ Create an AsyncLLM. @@ -111,14 +115,22 @@ def __init__( log_stats=self.log_stats) # EngineCore (starts the engine in background process). 
- core_client_class = AsyncMPClient if ( - vllm_config.parallel_config.data_parallel_size - == 1) else DPAsyncMPClient + core_client_class: Union[type[RayDPClient], type[DPAsyncMPClient], + type[AsyncMPClient]] + if vllm_config.parallel_config.data_parallel_size > 1: + if vllm_config.parallel_config.data_parallel_backend == "ray": + core_client_class = RayDPClient + else: + core_client_class = DPAsyncMPClient + else: + core_client_class = AsyncMPClient self.engine_core = core_client_class( vllm_config=vllm_config, executor_class=executor_class, log_stats=self.log_stats, + client_addresses=client_addresses, + client_index=client_index, ) if self.stat_loggers: for stat_logger in self.stat_loggers[0]: @@ -140,6 +152,8 @@ def from_vllm_config( stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_requests: bool = False, disable_log_stats: bool = False, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, ) -> "AsyncLLM": if not envs.VLLM_USE_V1: raise ValueError( @@ -157,6 +171,8 @@ def from_vllm_config( log_requests=not disable_log_requests, log_stats=not disable_log_stats, usage_context=usage_context, + client_addresses=client_addresses, + client_index=client_index, ) @classmethod @@ -190,6 +206,8 @@ def __del__(self): def shutdown(self): """Shutdown, cleaning up the background proc and IPC.""" + shutdown_prometheus() + if engine_core := getattr(self, "engine_core", None): engine_core.shutdown() @@ -393,7 +411,6 @@ async def output_handler(): # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. if stat_loggers: - assert outputs.scheduler_stats is not None AsyncLLM._record_stats( stat_loggers[outputs.engine_index], scheduler_stats=outputs.scheduler_stats, @@ -417,7 +434,7 @@ async def abort(self, request_id: str) -> None: @staticmethod def _record_stats( stat_loggers: list[StatLoggerBase], - scheduler_stats: SchedulerStats, + scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], ): """static so that it can be used from the output_handler task diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py new file mode 100644 index 00000000000..fa5b13ba23c --- /dev/null +++ b/vllm/v1/engine/coordinator.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: Apache-2.0 +import multiprocessing +import sys +import time +import weakref +from typing import Optional + +import msgspec.msgpack +import zmq + +from vllm.config import ParallelConfig +from vllm.logger import init_logger +from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket +from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType +from vllm.v1.serial_utils import MsgpackDecoder +from vllm.v1.utils import get_engine_client_zmq_addr, shutdown + +logger = init_logger(__name__) + + +class DPCoordinator: + + def __init__(self, parallel_config: ParallelConfig): + + # Assume coordinator is colocated with front-end procs. 
+ front_publish_address = get_open_zmq_ipc_path() + + dp_size = parallel_config.data_parallel_size + assert dp_size > 1, "Coordinator only used for data parallel" + + local_only = dp_size == parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + back_publish_address = get_engine_client_zmq_addr(local_only, host) + back_output_address = get_engine_client_zmq_addr(local_only, host) + + context = get_mp_context() + self.proc: multiprocessing.Process = context.Process( + target=CoordinatorProc.run_coordinator, + name="VLLM_DP_Coordinator", + kwargs={ + "engine_count": parallel_config.data_parallel_size, + "front_publish_address": front_publish_address, + "back_output_address": back_output_address, + "back_publish_address": back_publish_address, + }, + daemon=True) + self.proc.start() + + self.stats_publish_address = front_publish_address + self.coord_in_address = back_publish_address + self.coord_out_address = back_output_address + self._finalizer = weakref.finalize(self, shutdown, [self.proc]) + + def get_stats_publish_address(self) -> str: + return self.stats_publish_address + + def get_engine_socket_addresses(self) -> dict[str, str]: + return { + "coord_in_address": self.coord_in_address, + "coord_out_address": self.coord_out_address, + } + + def close(self): + self._finalizer() + + +class EngineState: + + def __init__(self): + self.request_counts = [0, 0] # [waiting, running] + + +class CoordinatorProc: + + def __init__(self, engine_count: int): + + self.ctx = zmq.Context() + + self.engines = [EngineState() for _ in range(engine_count)] + + self.current_wave = 0 + self.engines_running = False + self.stats_changed = False + + @staticmethod + def run_coordinator( + engine_count: int, + front_publish_address: str, + back_output_address: str, + back_publish_address: str, + ): + coordinator = CoordinatorProc(engine_count=engine_count) + + try: + coordinator.process_input_socket( + front_publish_address, + back_output_address, + back_publish_address, + ) + except KeyboardInterrupt: + logger.info("DP Coordinator process exiting") + + def process_input_socket(self, front_publish_address: str, + back_output_address: str, + back_publish_address: str): + + decoder = MsgpackDecoder(EngineCoreOutputs) + + with make_zmq_socket( + path=front_publish_address, # IPC + ctx=self.ctx, + socket_type=zmq.XPUB, + bind=True, + ) as publish_front, make_zmq_socket( + path=back_output_address, # IPC or TCP + ctx=self.ctx, + socket_type=zmq.PULL, + bind=True, + ) as output_back, make_zmq_socket( + path=back_publish_address, # IPC or TCP + ctx=self.ctx, + socket_type=zmq.XPUB, + bind=True, + ) as publish_back: + + poller = zmq.Poller() + poller.register(publish_front, zmq.POLLIN) + poller.register(output_back, zmq.POLLIN) + last_publish = 0 + while True: + elapsed = int(time.time() * 1000) - last_publish + wait_for = 100 if self.stats_changed else 3000 + events = poller.poll(timeout=max(0, wait_for - elapsed)) + if not events: + engine_list = self._get_engine_counts() + to_publish = (engine_list, self.current_wave, + self.engines_running) + msg = msgspec.msgpack.encode(to_publish) + publish_front.send(msg) + last_publish = int(time.time() * 1000) + self.stats_changed = False + continue + + events = dict(events) + + if publish_front in events: + buffer = publish_front.recv() + if buffer == b'\x01': + # Ignore subscription messages. 
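For orientation, the stats feed the coordinator publishes on `front_publish_address` is a plain msgpack tuple: `(per-engine [num_waiting, num_running] counts, current_wave, engines_running)`, re-published roughly every 100 ms while counts are changing and every few seconds otherwise. A hedged consumer sketch assuming only that socket and payload format (`coordinator` stands for a live `DPCoordinator` instance):

```python
import msgspec.msgpack
import zmq

stats_address = coordinator.get_stats_publish_address()  # IPC path

ctx = zmq.Context()
sub = ctx.socket(zmq.SUB)
sub.setsockopt(zmq.SUBSCRIBE, b"")  # arrives at the XPUB side as b'\x01'
sub.connect(stats_address)

counts, current_wave, engines_running = msgspec.msgpack.decode(sub.recv())
print(counts, current_wave, engines_running)  # e.g. [[0, 3], [1, 2]], 0, True
```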
+ continue + engine_index, wave = msgspec.msgpack.decode(buffer) + if wave < self.current_wave: + engine_index = None + if not self.engines_running: + self.engines_running = True + self.stats_changed = True + self._send_start_wave(publish_back, self.current_wave, + engine_index) + + if output_back in events: + buffer = output_back.recv() + outputs: EngineCoreOutputs = decoder.decode(buffer) + + assert not outputs.outputs + assert outputs.utility_output is None + + eng_index = outputs.engine_index + if outputs.scheduler_stats: + stats = self.engines[eng_index].request_counts + stats[0] = outputs.scheduler_stats.num_waiting_reqs + stats[1] = outputs.scheduler_stats.num_running_reqs + self.stats_changed = True + + #TODO record prometheus metrics here? + + if outputs.wave_complete is not None: + if self.current_wave <= wave: + self.current_wave = wave + 1 + self.engines_running = False + self.stats_changed = True + elif outputs.start_wave is not None and ( + wave > self.current_wave or + (wave == self.current_wave + and not self.engines_running)): + # Engine received request for a non-current wave so + # we must ensure that other engines progress to the + # next wave. + self.current_wave = wave + self.engines_running = True + self.stats_changed = True + self._send_start_wave(publish_back, wave, eng_index) + + @staticmethod + def _send_start_wave(socket: zmq.Socket, wave: int, + exclude_engine_index: Optional[int]): + wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index)) + socket.send_multipart( + (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) + + def _get_engine_counts(self) -> list[list[int]]: + return [e.request_counts for e in self.engines] + + def _get_engine_list(self) -> Optional[list[int]]: + shortlist: list[int] = [] + min_counts = [sys.maxsize, sys.maxsize] + for i, e in enumerate(self.engines): + if e.request_counts <= min_counts: + if e.request_counts < min_counts: + min_counts = e.request_counts + shortlist.clear() + shortlist.append(i) + return None if len(shortlist) == len(self.engines) else shortlist diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0cf2383af1c..c44ccf8dfa4 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -7,6 +7,7 @@ import time from collections import deque from concurrent.futures import Future +from contextlib import ExitStack from inspect import isclass, signature from logging import DEBUG from typing import Any, Callable, Optional, TypeVar, Union @@ -22,7 +23,7 @@ from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx +from vllm.utils import make_zmq_socket, resolve_obj_by_qualname from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface @@ -33,6 +34,7 @@ from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -208,16 +210,13 @@ def execute_model(self, scheduler_output: SchedulerOutput): # Re-raise exception raise err - def step(self) -> EngineCoreOutputs: + def step(self) -> dict[int, EngineCoreOutputs]: """Schedule, execute, and make 
output.""" # Check for any requests remaining in the scheduler - unfinished, # or finished and not yet removed from the batch. if not self.scheduler.has_requests(): - return EngineCoreOutputs( - outputs=[], - scheduler_stats=self.scheduler.make_stats(), - ) + return {} scheduler_output = self.scheduler.schedule() model_output = self.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( @@ -225,7 +224,7 @@ def step(self) -> EngineCoreOutputs: return engine_core_outputs - def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: + def step_with_batch_queue(self) -> Optional[dict[int, EngineCoreOutputs]]: """Schedule and execute batches with the batch queue. Note that if nothing to output in this step, None is returned. @@ -267,8 +266,8 @@ def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: # Blocking until the first result is available. model_output = future.result() self.batch_queue.task_done() - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output) + engine_core_outputs = (self.scheduler.update_from_output( + scheduler_output, model_output)) return engine_core_outputs @@ -346,7 +345,7 @@ def __init__( self, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, engine_index: int = 0, @@ -359,15 +358,22 @@ def __init__( # Create input socket. input_ctx = zmq.Context() identity = engine_index.to_bytes(length=2, byteorder="little") - input_socket = make_zmq_socket(input_ctx, - input_address, - zmq.DEALER, - identity=identity, - bind=False) - try: + with make_zmq_socket(input_ctx, + handshake_address, + zmq.DEALER, + identity=identity, + linger=5000, + bind=False) as handshake_socket: + # Register engine with front-end. - output_address = self.startup_handshake( - input_socket, on_head_node, vllm_config.parallel_config) + addresses = self.startup_handshake(handshake_socket, on_head_node, + vllm_config.parallel_config) + input_addresses: list[str] = addresses["input_addresses"] + output_addresses: list[str] = addresses["output_addresses"] + coord_in_addr: Optional[str] = addresses.get("coord_in_address") + coord_out_addr: Optional[str] = addresses.get("coord_out_address") + self.client_count = len(output_addresses) + self.coordinator = coord_out_addr is not None # Update config which may have changed from the handshake. vllm_config.__post_init__() @@ -379,45 +385,44 @@ def __init__( super().__init__(vllm_config, executor_class, log_stats, executor_fail_callback) + self.engine_index = engine_index self.step_fn = (self.step if self.batch_queue is None else self.step_with_batch_queue) self.engines_running = False + self.last_counts = (0, 0) # Send ready message. num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks - input_socket.send( + handshake_socket.send( msgspec.msgpack.encode({ "status": "READY", "local": on_head_node, "num_gpu_blocks": num_gpu_blocks, })) - # Background Threads and Queues for IO. These enable us to - # overlap ZMQ socket IO with GPU since they release the GIL, - # and to overlap some serialization/deserialization with the - # model forward pass. - # Threads handle Socket <-> Queues and core_busy_loop uses Queue. 
- self.input_queue = input_queue - self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() - threading.Thread(target=self.process_input_socket, - args=(input_socket, ), - daemon=True).start() - input_socket = None - self.output_thread = threading.Thread( - target=self.process_output_socket, - args=(output_address, engine_index), - daemon=True) - self.output_thread.start() - finally: - if input_socket is not None: - input_socket.close(linger=0) + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. + self.input_queue = input_queue + self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], + bytes]]() + threading.Thread(target=self.process_input_sockets, + args=(input_addresses, coord_in_addr, identity), + daemon=True).start() + self.output_thread = threading.Thread( + target=self.process_output_sockets, + args=(output_addresses, coord_out_addr, engine_index), + daemon=True) + self.output_thread.start() @staticmethod - def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, - parallel_config: ParallelConfig) -> str: + def startup_handshake(handshake_socket: zmq.Socket, on_head_node: bool, + parallel_config: ParallelConfig) -> dict[str, Any]: # Send registration message. - input_socket.send( + handshake_socket.send( msgspec.msgpack.encode({ "status": "HELLO", "local": on_head_node, @@ -425,22 +430,19 @@ def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, # Receive initialization message. logger.info("Waiting for init message from front-end.") - if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000): + if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000): raise RuntimeError("Did not receive response from front-end " f"process within {HANDSHAKE_TIMEOUT_MINS} " f"minutes") - init_bytes = input_socket.recv() + init_bytes = handshake_socket.recv() init_message = msgspec.msgpack.decode(init_bytes) logger.debug("Received init message: %s", init_message) - output_socket_address = init_message["output_socket_address"] - #TBD(nick) maybe replace IP with configured head node address - - received_parallel_config = init_message["parallel_config"] + received_parallel_config = init_message.pop("parallel_config") for key, value in received_parallel_config.items(): setattr(parallel_config, key, value) - return output_socket_address + return init_message["addresses"] @staticmethod def run_engine_core(*args, @@ -532,9 +534,22 @@ def _process_engine_step(self): # Step the engine core. outputs = self.step_fn() + if not outputs: + return + # Put EngineCoreOutputs into the output queue. 
- if outputs is not None: - self.output_queue.put_nowait(outputs) + for output in outputs.items(): + self.output_queue.put_nowait(output) + + if self.coordinator: + # If there is a DP coordinator, publish our request counts + # (if they've changed) + counts = self.scheduler.get_request_counts() + if counts != self.last_counts: + self.last_counts = counts + stats = SchedulerStats(*counts) + self.output_queue.put_nowait( + (-1, EngineCoreOutputs(scheduler_stats=stats))) def _handle_client_request(self, request_type: EngineCoreRequestType, request: Any) -> None: @@ -545,7 +560,7 @@ def _handle_client_request(self, request_type: EngineCoreRequestType, elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) elif request_type == EngineCoreRequestType.UTILITY: - call_id, method_name, args = request + client_idx, call_id, method_name, args = request output = UtilityOutput(call_id) try: method = getattr(self, method_name) @@ -556,7 +571,7 @@ def _handle_client_request(self, request_type: EngineCoreRequestType, output.failure_message = (f"Call to {method_name} method" f" failed: {str(e)}") self.output_queue.put_nowait( - EngineCoreOutputs(utility_output=output)) + (client_idx, EngineCoreOutputs(utility_output=output))) elif request_type == EngineCoreRequestType.EXECUTOR_FAILED: raise RuntimeError("Executor failed.") else: @@ -589,27 +604,68 @@ def _send_engine_dead(self): logger.fatal("vLLM shutdown signal from EngineCore failed " "to send. Please report this issue.") - def process_input_socket(self, input_socket: zmq.Socket): + def process_input_sockets(self, input_addresses: list[str], + coord_input_address: Optional[str], + identity: bytes): """Input socket IO thread.""" # Msgpack serialization decoding. add_request_decoder = MsgpackDecoder(EngineCoreRequest) generic_decoder = MsgpackDecoder() - while True: - # (RequestType, RequestData) - type_frame, *data_frames = input_socket.recv_multipart(copy=False) - request_type = EngineCoreRequestType(bytes(type_frame.buffer)) - - # Deserialize the request data. - decoder = add_request_decoder if ( - request_type == EngineCoreRequestType.ADD) else generic_decoder - request = decoder.decode(data_frames) - - # Push to input queue for core busy loop. - self.input_queue.put_nowait((request_type, request)) + with ExitStack() as stack, zmq.Context() as ctx: + input_sockets = [ + stack.enter_context( + make_zmq_socket(ctx, + input_address, + zmq.DEALER, + identity=identity, + bind=False)) + for input_address in input_addresses + ] + if coord_input_address is None: + coord_socket = None + else: + coord_socket = stack.enter_context( + make_zmq_socket(ctx, + coord_input_address, + zmq.XSUB, + identity=identity, + bind=False)) + # Send subscription message to coordinator. + coord_socket.send(b'\x01') + + # Register sockets with poller. + poller = zmq.Poller() + for input_socket in input_sockets: + # Send initial message to each input socket - this is required + # before the front-end ROUTER socket can send input messages + # back to us. + input_socket.send(b'') + poller.register(input_socket, zmq.POLLIN) + if coord_socket is not None: + poller.register(coord_socket, zmq.POLLIN) - def process_output_socket(self, output_path: str, engine_index: int): + while True: + for input_socket, _ in poller.poll(): + # (RequestType, RequestData) + type_frame, *data_frames = input_socket.recv_multipart( + copy=False) + request_type = EngineCoreRequestType( + bytes(type_frame.buffer)) + + # Deserialize the request data. 
+ decoder = add_request_decoder if ( + request_type + == EngineCoreRequestType.ADD) else generic_decoder + request = decoder.decode(data_frames) + + # Push to input queue for core busy loop. + self.input_queue.put_nowait((request_type, request)) + + def process_output_sockets(self, output_paths: list[str], + coord_output_path: Optional[str], + engine_index: int): """Output socket IO thread.""" # Msgpack serialization encoding. @@ -623,30 +679,50 @@ def process_output_socket(self, output_path: str, engine_index: int): # We must set linger to ensure the ENGINE_CORE_DEAD # message is sent prior to closing the socket. - with zmq_socket_ctx(output_path, zmq.constants.PUSH, - linger=4000) as socket: + with ExitStack() as stack, zmq.Context() as ctx: + sockets = [ + stack.enter_context( + make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)) + for output_path in output_paths + ] + coord_socket = stack.enter_context( + make_zmq_socket( + ctx, coord_output_path, zmq.PUSH, bind=False, + linger=4000)) if coord_output_path is not None else None + max_reuse_bufs = len(sockets) + 1 + while True: - outputs = self.output_queue.get() - if outputs == EngineCoreProc.ENGINE_CORE_DEAD: - socket.send(outputs, copy=False) + output = self.output_queue.get() + if output == EngineCoreProc.ENGINE_CORE_DEAD: + for socket in sockets: + socket.send(output) + #TODO also send to coordinator here? break - assert not isinstance(outputs, bytes) + assert not isinstance(output, bytes) + client_index, outputs = output outputs.engine_index = engine_index + if client_index == -1: + # Don't reuse buffer for coordinator message + # which will be very small. + assert coord_socket is not None + coord_socket.send_multipart(encoder.encode(outputs)) + continue + # Reclaim buffers that zmq is finished with. while pending and pending[-1][0].done: reuse_buffers.append(pending.pop()[2]) buffer = reuse_buffers.pop() if reuse_buffers else bytearray() buffers = encoder.encode_into(outputs, buffer) - tracker = socket.send_multipart(buffers, - copy=False, - track=True) + tracker = sockets[client_index].send_multipart(buffers, + copy=False, + track=True) if not tracker.done: ref = outputs if len(buffers) > 1 else None pending.appendleft((tracker, ref, buffer)) - elif len(reuse_buffers) < 2: - # Keep at most 2 buffers to reuse. + elif len(reuse_buffers) < max_reuse_bufs: + # Limit the number of buffers to reuse. reuse_buffers.append(buffer) @@ -658,7 +734,7 @@ def __init__( self, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -673,10 +749,11 @@ def __init__( # Counts forward-passes of the model so that we can synchronize # finished with DP peers every N steps. self.counter = 0 + self.current_wave = 0 # Initialize the engine. dp_rank = vllm_config.parallel_config.data_parallel_rank - super().__init__(vllm_config, on_head_node, input_address, + super().__init__(vllm_config, on_head_node, handshake_address, executor_class, log_stats, dp_rank) def _init_data_parallel(self, vllm_config: VllmConfig): @@ -699,7 +776,6 @@ def _init_data_parallel(self, vllm_config: VllmConfig): self.local_dp_rank = local_dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() - self.current_wave = 0 def shutdown(self): super().shutdown() @@ -714,15 +790,16 @@ def add_request(self, request: EngineCoreRequest): # Request received for an already-completed wave, notify # front-end that we need to start the next one. 
self.output_queue.put_nowait( - EngineCoreOutputs(start_wave=self.current_wave)) + (-1, EngineCoreOutputs(start_wave=self.current_wave))) super().add_request(request) def _handle_client_request(self, request_type: EngineCoreRequestType, request: Any) -> None: if request_type == EngineCoreRequestType.START_DP_WAVE: - new_wave: int = request - if new_wave >= self.current_wave: + new_wave, exclude_eng_index = request + if exclude_eng_index != self.engine_index and ( + new_wave >= self.current_wave): self.current_wave = new_wave if not self.engines_running: logger.debug("EngineCore starting idle loop for wave %d.", @@ -775,7 +852,8 @@ def run_busy_loop(self): logger.debug("Wave %d finished, pausing engine loop.", self.current_wave) self.output_queue.put_nowait( - EngineCoreOutputs(wave_complete=self.current_wave)) + (-1, + EngineCoreOutputs(wave_complete=self.current_wave))) self.current_wave += 1 def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: @@ -788,3 +866,108 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) + + +class DPEngineCoreActor(DPEngineCoreProc): + """ + Ray Actor for running EngineCore in a data parallel context + """ + + def __init__( + self, + vllm_config: VllmConfig, + on_head_node: bool, + addresses, + executor_class: type[Executor], + log_stats: bool, + engine_index: int = 0, + dp_rank: int = 0, + local_dp_rank: int = 0, + ): + # TODO(rui): improve shutdown handling + + # Ensure we can serialize transformer config after spawning + maybe_register_config_serialize_by_value() + + parallel_config: ParallelConfig = vllm_config.parallel_config + assert parallel_config.data_parallel_size > 1 or dp_rank > 0 + # Set data parallel rank for this engine process. + parallel_config.data_parallel_rank = dp_rank + parallel_config.data_parallel_rank_local = local_dp_rank + + input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() + + executor_fail_callback = lambda: input_queue.put_nowait( + (EngineCoreRequestType.EXECUTOR_FAILED, b'')) + + input_addresses: list[str] = addresses["input_addresses"] + output_addresses: list[str] = addresses["output_addresses"] + coord_in_addr: Optional[str] = addresses.get("coord_in_address") + coord_out_addr: Optional[str] = addresses.get("coord_out_address") + self.client_count = len(output_addresses) + self.coordinator = coord_out_addr is not None + + # Ray sets CUDA_VISIBLE_DEVICES to empty string, + # we clean this up to be able to properly initialize + # data parallel groups. + del os.environ['CUDA_VISIBLE_DEVICES'] + # Set up data parallel environment. + self._init_data_parallel(vllm_config) + + # Counts forward-passes of the model so that we can synchronize + # finished with DP peers every N steps. + self.counter = 0 + self.current_wave = 0 + + # Initialize engine core and model. + EngineCore.__init__(self, vllm_config, executor_class, log_stats, + executor_fail_callback) + + self.engine_index = engine_index + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) + self.engines_running = False + self.last_counts = (0, 0) + + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. 
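On the output side, both EngineCoreProc and the new Ray actor put (client_index, EngineCoreOutputs) pairs on the output queue; process_output_sockets above then sends index -1 to the coordinator's PUSH socket and every other index to the PUSH socket of the API-server client that owns the request, capping reusable serialization buffers at len(sockets) + 1. The simplified loop below captures only the routing rule; plain lists stand in for sockets, and buffer reuse and msgpack encoding are omitted.

# Simplified, illustrative routing loop; not the actual vLLM IO thread.
import queue
from typing import Any

def dispatch_outputs(out_q: "queue.Queue[Any]",
                     client_sinks: list[list[Any]],
                     coord_sink: list[Any],
                     shutdown_marker: object) -> None:
    """Route (client_index, payload) items; index -1 goes to the coordinator."""
    while True:
        item = out_q.get()
        if item is shutdown_marker:      # engine-dead style sentinel
            for sink in client_sinks:
                sink.append(shutdown_marker)
            return
        client_index, payload = item
        if client_index == -1:
            coord_sink.append(payload)   # small stats message for the DP coordinator
        else:
            client_sinks[client_index].append(payload)

out_q: "queue.Queue[Any]" = queue.Queue()
STOP = object()
clients: list[list[Any]] = [[], []]
coordinator: list[Any] = []
out_q.put((0, "outputs-for-client-0"))
out_q.put((-1, "scheduler-stats"))
out_q.put(STOP)
dispatch_outputs(out_q, clients, coordinator, STOP)
print(coordinator)    # ['scheduler-stats']
print(clients[0][0])  # 'outputs-for-client-0'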
+ self.input_queue = input_queue + self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], + bytes]]() + identity = engine_index.to_bytes(length=2, byteorder="little") + threading.Thread(target=self.process_input_sockets, + args=(input_addresses, coord_in_addr, identity), + daemon=True).start() + self.output_thread = threading.Thread( + target=self.process_output_sockets, + args=(output_addresses, coord_out_addr, engine_index), + daemon=True) + self.output_thread.start() + + def wait_for_init(self): + """ + Wait until the engine core is initialized. + + This is just an empty method. When ray.get() on this method + (or any other method of the actor) returns, it is guaranteed + that actor creation (i.e., __init__) is complete. + """ + pass + + def run(self): + """ + Run the engine core busy loop. + """ + try: + self.run_busy_loop() + except SystemExit: + logger.debug("EngineCore exiting.") + raise + except Exception: + logger.exception("EngineCore encountered a fatal error.") + raise + finally: + self.shutdown() diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 0d52bc9a681..f52692ad478 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -2,6 +2,7 @@ import asyncio import contextlib import queue +import sys import uuid import weakref from abc import ABC, abstractmethod @@ -9,26 +10,28 @@ from collections.abc import Awaitable, Sequence from concurrent.futures import Future from dataclasses import dataclass -from enum import Enum, auto from threading import Thread from typing import Any, Callable, Optional, TypeVar, Union -import msgspec +import msgspec.msgpack import zmq import zmq.asyncio -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import (get_open_port, get_open_zmq_inproc_path, - get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket) +from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket, + zmq_socket_ctx) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) +from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr -from vllm.v1.utils import CoreEngineProcManager +from vllm.v1.utils import (CoreEngine, CoreEngineActorManager, + CoreEngineProcManager, get_engine_client_zmq_addr, + wait_for_engine_startup) logger = init_logger(__name__) @@ -36,8 +39,6 @@ _R = TypeVar('_R') # Return type for collective_rpc -STARTUP_POLL_PERIOD_MS = 10000 - class EngineCoreClient(ABC): """ @@ -67,7 +68,11 @@ def make_client( if multiprocess_mode and asyncio_mode: if vllm_config.parallel_config.data_parallel_size > 1: - return DPAsyncMPClient(vllm_config, executor_class, log_stats) + if vllm_config.parallel_config.data_parallel_backend == "ray": + return RayDPClient(vllm_config, executor_class, log_stats) + else: + return DPAsyncMPClient(vllm_config, executor_class, + log_stats) return AsyncMPClient(vllm_config, executor_class, log_stats) @@ -206,7 +211,7 @@ def __init__(self, *args, **kwargs): self.engine_core = EngineCore(*args, **kwargs) def get_output(self) -> EngineCoreOutputs: - return self.engine_core.step() + return self.engine_core.step().get(0) or EngineCoreOutputs() def add_request(self, request: 
EngineCoreRequest) -> None: self.engine_core.add_request(request) @@ -265,34 +270,19 @@ def collective_rpc(self, return self.engine_core.collective_rpc(method, timeout, args, kwargs) -class CoreEngineState(Enum): - NEW = auto() - CONNECTED = auto() - READY = auto() - - -class CoreEngine: - """One per data parallel rank.""" - - def __init__(self, index: int = 0, local: bool = True): - self.local = local - self.index = index - self.identity = index.to_bytes(length=2, byteorder="little") - - self.state = CoreEngineState.NEW - self.num_reqs_in_flight = 0 - - @dataclass class BackgroundResources: """Used as a finalizer for clean shutdown, avoiding circular reference back to the client object.""" ctx: Union[zmq.Context] - local_engine_manager: Optional[CoreEngineProcManager] = None + local_engine_manager: Optional[Union[CoreEngineProcManager, + CoreEngineActorManager]] = None + coordinator: Optional[DPCoordinator] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None output_queue_task: Optional[asyncio.Task] = None + stats_update_task: Optional[asyncio.Task] = None shutdown_path: Optional[str] = None # Set if any of the engines are dead. Here so that the output @@ -305,9 +295,13 @@ def __call__(self): self.engine_dead = True if self.local_engine_manager is not None: self.local_engine_manager.close() + if self.coordinator is not None: + self.coordinator.close() if self.output_queue_task is not None: self.output_queue_task.cancel() + if self.stats_update_task is not None: + self.stats_update_task.cancel() # ZMQ context termination can hang if the sockets # aren't explicitly closed first. @@ -349,6 +343,7 @@ def __init__( vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, ): self.vllm_config = vllm_config # Serialization setup. @@ -368,8 +363,8 @@ def __init__( try: parallel_config = vllm_config.parallel_config local_engine_count = parallel_config.data_parallel_size_local - start_index = parallel_config.data_parallel_rank local_start_index = parallel_config.data_parallel_rank_local + dp_size = parallel_config.data_parallel_size # SPMD mode is where there is an LLM instance per DP rank and # one core engine per LLM, see @@ -381,22 +376,86 @@ def __init__( CoreEngine(index=local_start_index, local=True) ] else: - assert start_index == 0 + assert parallel_config.data_parallel_rank == 0 local_start_index = 0 self.core_engines = [ CoreEngine(index=i, local=(i < local_engine_count)) - for i in range(parallel_config.data_parallel_size) + for i in range(dp_size) ] - input_address, output_address = self._get_zmq_addresses( - parallel_config, spmd_mode) + local_only = spmd_mode or local_engine_count == dp_size + + self.stats_update_address: Optional[str] = None + if client_addresses is not None: + input_address = client_addresses["input_address"] + output_address = client_addresses["output_address"] + self.stats_update_address = client_addresses.get( + "stats_update_address") + else: + host = parallel_config.data_parallel_master_ip + input_address = get_engine_client_zmq_addr(local_only, host) + output_address = get_engine_client_zmq_addr(local_only, host) # Create input and output sockets. 
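The input_address and output_address values chosen just above come from get_engine_client_zmq_addr, a small helper added to vllm/v1/utils.py later in this diff: when every engine lives on the local node an IPC path suffices, otherwise a TCP address on the data-parallel master host is handed out. A usage sketch; the printed values are examples, since the IPC path and free port are random at runtime.

# Usage sketch for the get_engine_client_zmq_addr helper added in this diff.
from vllm.v1.utils import get_engine_client_zmq_addr

print(get_engine_client_zmq_addr(local_only=True, host="10.0.0.1"))
# e.g. ipc:///tmp/....sock    all engines are local, IPC is enough
print(get_engine_client_zmq_addr(local_only=False, host="10.0.0.1"))
# e.g. tcp://10.0.0.1:<port>  remote engines must reach the client over TCP
print(get_engine_client_zmq_addr(local_only=False, host="10.0.0.1", port=5570))
# tcp://10.0.0.1:5570         fixed port, as used for the startup handshake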
self.input_socket = self.resources.input_socket = make_zmq_socket( self.ctx, input_address, zmq.ROUTER, bind=True) - self.resources.output_socket = make_zmq_socket( - self.ctx, output_address, zmq.constants.PULL) + self.ctx, output_address, zmq.PULL) + + if client_addresses is None: + self._init_engines_direct(vllm_config, local_only, + local_start_index, input_address, + output_address, executor_class, + log_stats) + coordinator = self.resources.coordinator + if coordinator: + self.stats_update_address = ( + coordinator.get_stats_publish_address()) + + # Wait for ready messages from each engine on the input socket. + identities = set(e.identity for e in self.core_engines) + sync_input_socket = zmq.Socket.shadow(self.input_socket) + while identities: + if not sync_input_socket.poll(timeout=600_000): + raise TimeoutError("Timed out waiting for engines to send" + "initial message on input socket.") + identity, _ = sync_input_socket.recv_multipart() + identities.remove(identity) + + self.core_engine = self.core_engines[0] + self.utility_results: dict[int, AnyFuture] = {} + + # Request objects which may contain pytorch-allocated tensors + # that we need to keep references to until zmq is done with the + # underlying data. + self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]() + + success = True + finally: + if not success: + self._finalizer() + + def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, + local_start_index: int, input_address: str, + output_address: str, + executor_class: type[Executor], log_stats: bool): + """Self-contained client mode, launch engine and coordinator process + as needed.""" + + parallel_config = vllm_config.parallel_config + local_engine_count = parallel_config.data_parallel_size_local + start_index = parallel_config.data_parallel_rank + host = parallel_config.data_parallel_master_ip + + if len(self.core_engines) > 1: + self.resources.coordinator = DPCoordinator(parallel_config) + + handshake_address = get_engine_client_zmq_addr( + local_only, host, parallel_config.data_parallel_rpc_port) + + with zmq_socket_ctx(handshake_address, zmq.ROUTER, + bind=True) as handshake_socket: + # Start local engines. if local_engine_count: # In server mode, start_index and local_start_index will @@ -406,139 +465,42 @@ def __init__( vllm_config=vllm_config, executor_class=executor_class, log_stats=log_stats, - input_address=input_address, + handshake_address=handshake_address, on_head_node=True, local_engine_count=local_engine_count, start_index=start_index, local_start_index=local_start_index) - self.core_engine = self.core_engines[0] - # Wait for engine core process(es) to start. - self._wait_for_engine_startup(output_address, parallel_config) + self._wait_for_engine_startup(handshake_socket, input_address, + output_address) - self.utility_results: dict[int, AnyFuture] = {} + def _wait_for_engine_startup(self, handshake_socket: zmq.Socket, + input_address: str, output_address: str): + addresses: dict[str, Any] = { + "input_addresses": [input_address], + "output_addresses": [output_address], + } - # Request objects which may contain pytorch-allocated tensors - # that we need to keep references to until zmq is done with the - # underlying data. 
- self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]() + coordinator = self.resources.coordinator + if coordinator is not None: + addresses.update(coordinator.get_engine_socket_addresses()) - success = True - finally: - if not success: - self._finalizer() - - @staticmethod - def _get_zmq_addresses(parallel_config: ParallelConfig, - spmd_mode: bool) -> tuple[str, str]: - """Returns (input_address, output_address).""" - dp_size = parallel_config.data_parallel_size - local_engine_count = parallel_config.data_parallel_size_local - - if local_engine_count == dp_size or spmd_mode: - input_address = get_open_zmq_ipc_path() - output_address = get_open_zmq_ipc_path() - else: - host = parallel_config.data_parallel_master_ip - input_port = parallel_config.data_parallel_rpc_port - output_port = get_open_port() - input_address = get_tcp_uri(host, input_port) - output_address = get_tcp_uri(host, output_port) - - return input_address, output_address - - def _wait_for_engine_startup(self, output_address: str, - parallel_config: ParallelConfig): - # Get a sync handle to the socket which can be sync or async. - sync_input_socket = zmq.Socket.shadow(self.input_socket) - - # Wait for engine core process(es) to send ready messages. - local_count = parallel_config.data_parallel_size_local - remote_count = len(self.core_engines) - local_count - # [local, remote] counts - conn_pending, start_pending = [local_count, remote_count], [0, 0] - - poller = zmq.Poller() - poller.register(sync_input_socket, zmq.POLLIN) proc_manager = self.resources.local_engine_manager if proc_manager is not None: - for sentinel in proc_manager.sentinels(): - poller.register(sentinel, zmq.POLLIN) - while any(conn_pending) or any(start_pending): - events = poller.poll(STARTUP_POLL_PERIOD_MS) - if not events: - if any(conn_pending): - logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to connect.", *conn_pending) - if any(start_pending): - logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to start.", *start_pending) - continue - if len(events) > 1 or events[0][0] != sync_input_socket: - # One of the local core processes exited. - finished = proc_manager.finished_procs( - ) if proc_manager else {} - raise RuntimeError("Engine core initialization failed. " - "See root cause above. " - f"Failed core proc(s): {finished}") - - # Receive HELLO and READY messages from the input socket. - eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() - eng_index = int.from_bytes(eng_identity, byteorder="little") - engine = next( - (e for e in self.core_engines if e.identity == eng_identity), - None) - if engine is None: - raise RuntimeError(f"Message from engine with unexpected data " - f"parallel rank: {eng_index}") - msg = msgspec.msgpack.decode(ready_msg_bytes) - status, local = msg["status"], msg["local"] - if local != engine.local: - raise RuntimeError(f"{status} message from " - f"{'local' if local else 'remote'} " - f"engine {eng_index}, expected it to be " - f"{'local' if engine.local else 'remote'}") - - if status == "HELLO" and engine.state == CoreEngineState.NEW: - - # Send init message with DP config info. 
- init_message = self.encoder.encode({ - "output_socket_address": output_address, - "parallel_config": { - "data_parallel_master_ip": - parallel_config.data_parallel_master_ip, - "data_parallel_master_port": - parallel_config.data_parallel_master_port, - "data_parallel_size": - parallel_config.data_parallel_size, - }, - }) - sync_input_socket.send_multipart((eng_identity, *init_message), - copy=False) - conn_pending[0 if local else 1] -= 1 - start_pending[0 if local else 1] += 1 - engine.state = CoreEngineState.CONNECTED - elif status == "READY" and (engine.state - == CoreEngineState.CONNECTED): - # Setup KV cache config with initialization state from - # engine core process. Sum values from all engines in DP case. - cache_config = self.vllm_config.cache_config - num_gpu_blocks = cache_config.num_gpu_blocks or 0 - num_gpu_blocks += msg['num_gpu_blocks'] - cache_config.num_gpu_blocks = num_gpu_blocks - - start_pending[0 if local else 1] -= 1 - engine.state = CoreEngineState.READY - else: - raise RuntimeError(f"Unexpected {status} message for " - f"{'local' if local else 'remote'} engine " - f"{eng_index} in {engine.state} state.") - - logger.debug("%s from %s core engine process %s.", status, - "local" if local else "remote", eng_index) + assert isinstance(proc_manager, CoreEngineProcManager), ( + "_wait_for_engine_startup should only be " + "called with CoreEngineProcManager") + + wait_for_engine_startup( + handshake_socket, + addresses, + self.core_engines, + self.vllm_config.parallel_config, + self.vllm_config.cache_config, + proc_manager, + coordinator.proc if coordinator else None, + ) def shutdown(self): # Terminate background resources. @@ -604,8 +566,8 @@ def process_outputs_socket(): try: shutdown_socket.bind(shutdown_path) poller = zmq.Poller() - poller.register(shutdown_socket) - poller.register(out_socket) + poller.register(shutdown_socket, zmq.POLLIN) + poller.register(out_socket, zmq.POLLIN) while True: socks = poller.poll() if not socks: @@ -667,7 +629,7 @@ def call_utility(self, method: str, *args) -> Any: future: Future[Any] = Future() self.utility_results[call_id] = future self._send_input(EngineCoreRequestType.UTILITY, - (call_id, method, args)) + (0, call_id, method, args)) return future.result() @@ -729,15 +691,21 @@ def save_sharded_state(self, class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], - log_stats: bool): + def __init__(self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0): super().__init__( asyncio_mode=True, vllm_config=vllm_config, executor_class=executor_class, log_stats=log_stats, + client_addresses=client_addresses, ) + self.client_index = client_index self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, Exception]]() try: @@ -853,12 +821,13 @@ async def _call_utility_async(self, method: str, *args, future = asyncio.get_running_loop().create_future() self.utility_results[call_id] = future message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( - (call_id, method, args))) + (self.client_index, call_id, method, args))) await self._send_input_message(message, engine, args) self._ensure_output_queue_task() return await future async def add_request_async(self, request: EngineCoreRequest) -> None: + request.client_index = self.client_index await self._send_input(EngineCoreRequestType.ADD, request) 
self._ensure_output_queue_task() @@ -920,17 +889,119 @@ class DPAsyncMPClient(AsyncMPClient): """Asyncio-compatible client for multi-proc, multi-engine (data parallel) EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], - log_stats: bool): + def __init__(self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0): self.current_wave = 0 self.engines_running = False + # To route aborts to the correct engine. self.reqs_in_flight: dict[str, CoreEngine] = {} - super().__init__(vllm_config, executor_class, log_stats) + super().__init__(vllm_config, executor_class, log_stats, + client_addresses, client_index) assert len(self.core_engines) > 1 + # List of [waiting, running] pair per engine. + self.lb_engines: list[list[int]] = [] + + self.first_req_sock_addr = get_open_zmq_inproc_path() + self.first_req_send_socket = make_zmq_socket(self.ctx, + self.first_req_sock_addr, + zmq.PAIR, + bind=True) + try: + # If we are running in an asyncio event loop, start the stats task. + # Otherwise, it will be started lazily. + asyncio.get_running_loop() + self._ensure_stats_update_task() + except RuntimeError: + pass + + def _ensure_stats_update_task(self): + resources = self.resources + if resources.stats_update_task is not None: + return + + assert self.stats_update_address is not None + + async def run_engine_stats_update_task(): + with make_zmq_socket(self.ctx, self.stats_update_address, + zmq.XSUB) as socket, make_zmq_socket( + self.ctx, + self.first_req_sock_addr, + zmq.PAIR, + bind=False) as first_req_rcv_socket: + # Send subscription message. + await socket.send(b'\x01') + + poller = zmq.asyncio.Poller() + poller.register(socket, zmq.POLLIN) + poller.register(first_req_rcv_socket, zmq.POLLIN) + + while True: + events = await poller.poll() + if not self.engines_running and len(events) == 2 or ( + events[0][0] == first_req_rcv_socket): + # Send a message to notify the coordinator that + # we're sending a request while the engines are + # paused, so that it can wake the others up + # (to run dummy EP loop). + self.engines_running = True + buf = first_req_rcv_socket.recv( + flags=zmq.NOBLOCK).result() + target_eng_index = int.from_bytes(buf, "little") + msg = msgspec.msgpack.encode( + (target_eng_index, self.current_wave)) + await socket.send(msg) + + buf = None + while True: + # Drain all stats events (we only care about latest). + future: asyncio.Future[bytes] = socket.recv( + flags=zmq.NOBLOCK) + if isinstance(future.exception(), zmq.Again): + break + buf = future.result() + if buf is None: + continue + + # Update local load-balancing state. + counts, wave, running = msgspec.msgpack.decode(buf) + self.current_wave = wave + self.engines_running = running + self.lb_engines = counts + + resources.stats_update_task = asyncio.create_task( + run_engine_stats_update_task()) + + def get_core_engine_for_request(self) -> CoreEngine: + if not self.lb_engines: + return self.core_engines[0] + # TODO use P2C alg for larger DP sizes + num_engines = len(self.lb_engines) + min_counts = [sys.maxsize, sys.maxsize] + eng_index = 0 + for i in range(num_engines): + # Start from client_index for to help with balancing when + # engines are empty. 
+ idx = (self.client_index + i) % num_engines + counts = self.lb_engines[idx] + if counts < min_counts: + min_counts = counts + eng_index = idx + # Adjust local counts for better balancing between stats updates + # from the coordinator (these are overwritten 10x per second). + if min_counts[0]: + min_counts[0] += 1 + else: + min_counts[1] += 1 + return self.core_engines[eng_index] + async def call_utility_async(self, method: str, *args) -> Any: # Only the result from the first engine is returned. return (await asyncio.gather(*[ @@ -939,62 +1010,30 @@ async def call_utility_async(self, method: str, *args) -> Any: ]))[0] async def add_request_async(self, request: EngineCoreRequest) -> None: + self._ensure_stats_update_task() + request.current_wave = self.current_wave + request.client_index = self.client_index chosen_engine = self.get_core_engine_for_request() self.reqs_in_flight[request.request_id] = chosen_engine - chosen_engine.num_reqs_in_flight += 1 to_await = self._send_input(EngineCoreRequestType.ADD, request, chosen_engine) if not self.engines_running: - # Send request to chosen engine and dp start loop - # control message to all other engines. - self.engines_running = True - to_await = asyncio.gather( - to_await, # type: ignore[assignment] - *self._start_wave_coros(exclude_index=chosen_engine.index)) + # Notify coordinator that we're sending a request + await self.first_req_send_socket.send(chosen_engine.identity) await to_await self._ensure_output_queue_task() - def get_core_engine_for_request(self) -> CoreEngine: - return min(self.core_engines, key=lambda e: e.num_reqs_in_flight) - @staticmethod async def process_engine_outputs(self: "DPAsyncMPClient", outputs: EngineCoreOutputs): - if self.reqs_in_flight: - for req_id in outputs.finished_requests or (): - if engine := self.reqs_in_flight.pop(req_id, None): - engine.num_reqs_in_flight -= 1 - - if outputs.wave_complete is not None: - # Current wave is complete, move to next wave number - # and mark engines as paused. - if self.current_wave <= outputs.wave_complete: - self.current_wave = outputs.wave_complete + 1 - self.engines_running = False - - elif outputs.start_wave is not None and ( - outputs.start_wave > self.current_wave or - (outputs.start_wave == self.current_wave - and not self.engines_running)): - # Engine received request for a non-current wave so we must ensure - # that other engines progress to the next wave. - self.current_wave = outputs.start_wave - self.engines_running = True - await asyncio.gather(*self._start_wave_coros( - exclude_index=outputs.engine_index)) - - def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]: - logger.debug("Sending start DP wave %d.", self.current_wave) - return [ - self._send_input(EngineCoreRequestType.START_DP_WAVE, - self.current_wave, engine) - for engine in self.core_engines if engine.index != exclude_index - ] + if outputs.finished_requests and self.reqs_in_flight: + for req_id in outputs.finished_requests: + self.reqs_in_flight.pop(req_id, None) async def abort_requests_async(self, request_ids: list[str]) -> None: if not request_ids: @@ -1018,3 +1057,168 @@ async def _abort_requests(self, request_ids: list[str], if not self.resources.engine_dead: await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) + + +class RayDPClient(DPAsyncMPClient): + """ + Ray-based client for multi-proc, multi-engine (data parallel) + EngineCore. 
+ """ + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, + ): + self.client_index = client_index + self.current_wave = 0 + self.engines_running = False + self.reqs_in_flight: dict[str, CoreEngine] = {} + + self.vllm_config = vllm_config + # Serialization setup. + self.encoder = MsgpackEncoder() + self.decoder = MsgpackDecoder(EngineCoreOutputs) + + # ZMQ setup. + sync_ctx = zmq.Context(io_threads=2) + self.ctx = zmq.asyncio.Context(sync_ctx) + + # List of [waiting, running] pair per engine. + self.lb_engines: list[list[int]] = [] + self.first_req_sock_addr = get_open_zmq_inproc_path() + self.first_req_send_socket = make_zmq_socket(self.ctx, + self.first_req_sock_addr, + zmq.PAIR, + bind=True) + + # This will ensure resources created so far are closed + # when the client is garbage collected, even if an + # exception is raised mid-construction. + self.resources = BackgroundResources(ctx=sync_ctx) + self._finalizer = weakref.finalize(self, self.resources) + success = False + try: + parallel_config = vllm_config.parallel_config + local_engine_count = parallel_config.data_parallel_size_local + start_index = parallel_config.data_parallel_rank + local_start_index = parallel_config.data_parallel_rank_local + + # SPMD mode is where there is an LLM instance per DP rank and + # one core engine per LLM, see + # examples/offline_inference/data_parallel.py. + spmd_mode = local_start_index is not None + if spmd_mode: + assert local_engine_count == 1 + self.core_engines = [ + CoreEngine(index=local_start_index, local=True) + ] + else: + assert start_index == 0 + local_start_index = 0 + self.core_engines = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(parallel_config.data_parallel_size) + ] + + dp_size = parallel_config.data_parallel_size + local_only = spmd_mode or local_engine_count == dp_size + + self.stats_update_address: Optional[str] = None + if client_addresses is not None: + input_address = client_addresses["input_address"] + output_address = client_addresses["output_address"] + self.stats_update_address = client_addresses.get( + "stats_update_address") + else: + host = parallel_config.data_parallel_master_ip + input_address = get_engine_client_zmq_addr(local_only, host) + output_address = get_engine_client_zmq_addr(local_only, host) + + # Create input and output sockets. + self.input_socket = self.resources.input_socket = make_zmq_socket( + self.ctx, input_address, zmq.ROUTER, bind=True) + self.resources.output_socket = make_zmq_socket( + self.ctx, output_address, zmq.PULL) + + if client_addresses is None: + self._init_engines_direct(vllm_config, local_only, + local_start_index, input_address, + output_address, executor_class, + log_stats) + coordinator = self.resources.coordinator + if coordinator: + self.stats_update_address = \ + coordinator.get_stats_publish_address() + + # Wait for ready messages from each engine on the input socket. 
+ identities = set(e.identity for e in self.core_engines) + sync_input_socket = zmq.Socket.shadow(self.input_socket) + while identities: + if not sync_input_socket.poll(timeout=600_000): + raise TimeoutError("Timed out waiting for engines to send" + "initial message on input socket.") + identity, _ = sync_input_socket.recv_multipart() + identities.remove(identity) + + self.core_engine = self.core_engines[0] + + self.utility_results: dict[int, AnyFuture] = {} + + # Request objects which may contain pytorch-allocated tensors + # that we need to keep references to until zmq is done with the + # underlying data. + self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]() + self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, + Exception]]() + + success = True + finally: + if not success: + self._finalizer() + + try: + # If we are running in an asyncio event loop, start the queue task. + # Otherwise, it will be started lazily. If it is not started here, + # we could miss EXECUTOR_FAILED messages from engine core if they + # occur prior to any requests being sent. + asyncio.get_running_loop() + self._ensure_output_queue_task() + except RuntimeError: + pass + + def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, + local_start_index: int, input_address: str, + output_address: str, + executor_class: type[Executor], log_stats: bool): + """Self-contained client mode, launch engine and coordinator process + as needed.""" + + parallel_config = vllm_config.parallel_config + local_engine_count = parallel_config.data_parallel_size_local + start_index = parallel_config.data_parallel_rank + + if len(self.core_engines) > 1: + self.resources.coordinator = DPCoordinator(parallel_config) + + addresses: dict[str, Any] = { + "input_addresses": [input_address], + "output_addresses": [output_address], + } + + coordinator = self.resources.coordinator + if coordinator is not None: + addresses.update(coordinator.get_engine_socket_addresses()) + + # Start all engines. + self.resources.local_engine_manager = CoreEngineActorManager( + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + addresses=addresses, + local_engine_count=local_engine_count, + start_index=start_index, + local_start_index=local_start_index) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 2b75a3a2ecb..4fa23b74ea5 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -12,13 +12,12 @@ from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason +from vllm.v1.metrics.prometheus import unregister_vllm_metrics from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) -_LOCAL_LOGGING_INTERVAL_SEC = 5.0 - StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] @@ -35,7 +34,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): ... @abstractmethod - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): ... 
@@ -78,20 +77,22 @@ def _get_throughput(self, tracked_stats: list[int], now: float) -> float: # Compute summary metrics for tracked stats return float(np.sum(tracked_stats) / (now - self.last_log_time)) - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): """Log Stats to standard output.""" if iteration_stats: self._track_iteration_stats(iteration_stats) - self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) + if scheduler_stats is not None: + self.prefix_caching_metrics.observe( + scheduler_stats.prefix_cache_stats) - if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_logging.observe( - scheduler_stats.spec_decoding_stats) + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_logging.observe( + scheduler_stats.spec_decoding_stats) - self.last_scheduler_stats = scheduler_stats + self.last_scheduler_stats = scheduler_stats def log(self): now = time.monotonic() @@ -131,10 +132,11 @@ def log(self): self.spec_decoding_logging.log(log_fn=log_fn) def log_engine_initialized(self): - logger.info( - "vllm cache_config_info with initialization " \ - "after num_gpu_blocks is: %d", - self.vllm_config.cache_config.num_gpu_blocks) + if self.vllm_config.cache_config.num_gpu_blocks: + logger.info( + "Engine %03d: vllm cache_config_info with initialization " + "after num_gpu_blocks is: %d", self.engine_index, + self.vllm_config.cache_config.num_gpu_blocks) class PrometheusStatLogger(StatLoggerBase): @@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase): _spec_decoding_cls = SpecDecodingProm def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): - self._unregister_vllm_metrics() + + unregister_vllm_metrics() self.vllm_config = vllm_config self.engine_index = engine_index # Use this flag to hide metrics that were deprecated in @@ -169,11 +172,13 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.gauge_scheduler_running = self._gauge_cls( name="vllm:num_requests_running", documentation="Number of requests in model execution batches.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) self.gauge_scheduler_waiting = self._gauge_cls( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) # @@ -182,6 +187,7 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.gauge_gpu_cache_usage = self._gauge_cls( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", + multiprocess_mode="mostrecent", labelnames=labelnames).labels(*labelvalues) self.counter_gpu_prefix_cache_queries = self._counter_cls( @@ -242,6 +248,9 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames).labels(*labelvalues) + # TODO: This metric might be incorrect in case of using multiple + # api_server counts which uses prometheus mp. + # See: https://github.com/vllm-project/vllm/pull/18053 self.histogram_iteration_tokens = \ self._histogram_cls( name="vllm:iteration_tokens_total", @@ -340,6 +349,9 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): # # LoRA metrics # + + # TODO: This metric might be incorrect in case of using multiple + # api_server counts which uses prometheus mp. 
self.gauge_lora_info: Optional[prometheus_client.Gauge] = None if vllm_config.lora_config is not None: self.labelname_max_lora = "max_lora" @@ -350,13 +362,16 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self._gauge_cls( name="vllm:lora_requests_info", documentation="Running stats on lora requests.", + multiprocess_mode="sum", labelnames=[ self.labelname_max_lora, self.labelname_waiting_lora_adapters, self.labelname_running_lora_adapters, - ]) + ], + ) def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): + metrics_info = config_obj.metrics_info() metrics_info["engine"] = self.engine_index @@ -372,25 +387,28 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): info_gauge = self._gauge_cls( name=name, documentation=documentation, - labelnames=metrics_info.keys()).labels(**metrics_info) + multiprocess_mode="mostrecent", + labelnames=metrics_info.keys(), + ).labels(**metrics_info) info_gauge.set(1) - def record(self, scheduler_stats: SchedulerStats, + def record(self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats]): """Log to prometheus.""" - self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) - self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) + if scheduler_stats is not None: + self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) + self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) - self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) + self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) - self.counter_gpu_prefix_cache_queries.inc( - scheduler_stats.prefix_cache_stats.queries) - self.counter_gpu_prefix_cache_hits.inc( - scheduler_stats.prefix_cache_stats.hits) + self.counter_gpu_prefix_cache_queries.inc( + scheduler_stats.prefix_cache_stats.queries) + self.counter_gpu_prefix_cache_hits.inc( + scheduler_stats.prefix_cache_stats.hits) - if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_prom.observe( - scheduler_stats.spec_decoding_stats) + if scheduler_stats.spec_decoding_stats is not None: + self.spec_decoding_prom.observe( + scheduler_stats.spec_decoding_stats) if iteration_stats is None: return @@ -445,13 +463,6 @@ def record(self, scheduler_stats: SchedulerStats, self.gauge_lora_info.labels(**lora_info_labels)\ .set_to_current_time() - @staticmethod - def _unregister_vllm_metrics(): - # Unregister any existing vLLM collectors (for CI/CD - for collector in list(prometheus_client.REGISTRY._collector_to_names): - if hasattr(collector, "_name") and "vllm" in collector._name: - prometheus_client.REGISTRY.unregister(collector) - def log_engine_initialized(self): self.log_metrics_info("cache_config", self.vllm_config.cache_config) diff --git a/vllm/v1/metrics/prometheus.py b/vllm/v1/metrics/prometheus.py new file mode 100644 index 00000000000..f1256853536 --- /dev/null +++ b/vllm/v1/metrics/prometheus.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import tempfile +from typing import Optional + +from prometheus_client import REGISTRY, CollectorRegistry, multiprocess + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# Global temporary directory for prometheus multiprocessing +_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None + + +def setup_multiprocess_prometheus(): + """Set up prometheus multiprocessing directory if not already configured. 
+ + """ + global _prometheus_multiproc_dir + + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + # Make TemporaryDirectory for prometheus multiprocessing + # Note: global TemporaryDirectory will be automatically + # cleaned up upon exit. + _prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name + logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s", + _prometheus_multiproc_dir.name) + else: + logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. " + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup.") + + +def get_prometheus_registry(): + """Get the appropriate prometheus registry based on multiprocessing + configuration. + + Returns: + Registry: A prometheus registry + """ + if os.getenv("PROMETHEUS_MULTIPROC_DIR") is not None: + logger.debug("Using multiprocess registry for prometheus metrics") + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + return registry + + return REGISTRY + + +def unregister_vllm_metrics(): + """Unregister any existing vLLM collectors from the prometheus registry. + + This is useful for testing and CI/CD where metrics may be registered + multiple times across test runs. + + Also, in case of multiprocess, we need to unregister the metrics from the + global registry. + """ + registry = REGISTRY + # Unregister any existing vLLM collectors + for collector in list(registry._collector_to_names): + if hasattr(collector, "_name") and "vllm" in collector._name: + registry.unregister(collector) + + +def shutdown_prometheus(): + """Shutdown prometheus metrics.""" + try: + pid = os.getpid() + multiprocess.mark_process_dead(pid) + logger.debug("Marked Prometheus metrics for process %d as dead", pid) + except Exception as e: + logger.error("Error during metrics cleanup: %s", str(e)) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index d1cdd2c5275..45b0459ca9b 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -20,18 +20,19 @@ class Request: def __init__( self, request_id: str, + client_index: int, prompt_token_ids: list[int], multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_hashes: Optional[list[str]], multi_modal_placeholders: Optional[list[PlaceholderRange]], sampling_params: SamplingParams, eos_token_id: Optional[int], - arrival_time: float, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, ) -> None: self.request_id = request_id + self.client_index = client_index self.sampling_params = sampling_params # Because of LoRA, the eos token id can be different for each request. 
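Because several API-server processes may now record metrics for the same engines, the new vllm/v1/metrics/prometheus.py module above moves prometheus_client into multiprocess mode: PROMETHEUS_MULTIPROC_DIR points at a shared directory, scraping goes through a MultiProcessCollector-backed registry, gauges declare a multiprocess_mode (such as "mostrecent" or "sum") so per-process values can be merged, and exiting processes are marked dead. Below is a minimal end-to-end prometheus_client sketch of that mechanism; the metric name and directory handling are illustrative, and "mostrecent" assumes a reasonably recent prometheus_client release.

# Minimal illustrative sketch of prometheus_client multiprocess mode.
import os
import tempfile

# The directory must be configured before any metrics are created/imported.
tmpdir = tempfile.TemporaryDirectory()
os.environ["PROMETHEUS_MULTIPROC_DIR"] = tmpdir.name

from prometheus_client import (CollectorRegistry, Gauge, generate_latest,
                               multiprocess)

# Each worker process creates its gauges with a multiprocess_mode that tells
# the collector how to merge per-process values at scrape time.
running = Gauge("demo_requests_running",
                "Requests currently running (illustrative).",
                multiprocess_mode="mostrecent")
running.set(7)

# The scraping process reads all per-process files via a fresh registry.
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
print(generate_latest(registry).decode())

# On worker exit, mark its series dead so stale files can be cleaned up.
multiprocess.mark_process_dead(os.getpid())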
self.eos_token_id = eos_token_id @@ -86,13 +87,13 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( request_id=request.request_id, + client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, multi_modal_inputs=request.mm_inputs, multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, sampling_params=request.sampling_params, eos_token_id=request.eos_token_id, - arrival_time=request.arrival_time, lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params), diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 0758747a83c..1470dad5b2b 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,31 +1,40 @@ # SPDX-License-Identifier: Apache-2.0 -import os +import argparse +import multiprocessing import time import weakref from collections import defaultdict from collections.abc import Sequence +from enum import Enum, auto from multiprocessing import Process, connection -from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union, - overload) +from multiprocessing.process import BaseProcess +from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, + Union, overload) +import msgspec import torch +import zmq -from vllm.config import VllmConfig +from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import get_mp_context, kill_process_tree +from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path, + get_tcp_uri, kill_process_tree) from vllm.v1.executor.abstract import Executor if TYPE_CHECKING: from vllm.attention.layer import Attention + from vllm.v1.engine.coordinator import DPCoordinator logger = init_logger(__name__) T = TypeVar("T") +STARTUP_POLL_PERIOD_MS = 10000 + class ConstantList(Generic[T], Sequence): @@ -95,6 +104,78 @@ def __repr__(self): return f"ConstantList({self._x})" +def get_engine_client_zmq_addr(local_only: bool, + host: str, + port: int = 0) -> str: + return get_open_zmq_ipc_path() if local_only else (get_tcp_uri( + host, port or get_open_port())) + + +class APIServerProcessManager: + """Manages a group of API server processes. + + Handles creation, monitoring, and termination of API server worker + processes. Also monitors extra processes to check if they are healthy. + """ + + def __init__( + self, + target_server_fn: Callable, + listen_address: str, + sock: Any, + args: argparse.Namespace, + num_servers: int, + input_addresses: list[str], + output_addresses: list[str], + stats_update_address: Optional[str] = None, + ): + """Initialize and start API server worker processes. 
+ + Args: + target_server_fn: Function to call for each API server process + listen_address: Address to listen for client connections + sock: Socket for client connections + args: Command line arguments + num_servers: Number of API server processes to start + input_addresses: Input addresses for each API server + output_addresses: Output addresses for each API server + stats_update_address: Optional stats update address + """ + self.listen_address = listen_address + self.sock = sock + self.args = args + + # Start API servers + spawn_context = multiprocessing.get_context("spawn") + self.processes: list[BaseProcess] = [] + + for i, in_addr, out_addr in zip(range(num_servers), input_addresses, + output_addresses): + client_config = { + "input_address": in_addr, + "output_address": out_addr, + "client_index": i + } + if stats_update_address is not None: + client_config["stats_update_address"] = stats_update_address + + proc = spawn_context.Process(target=target_server_fn, + name=f"ApiServer_{i}", + args=(listen_address, sock, args, + client_config)) + self.processes.append(proc) + proc.start() + + logger.info("Started %d API server processes", len(self.processes)) + + # Shutdown only the API server processes on garbage collection + # The extra processes are managed by their owners + self._finalizer = weakref.finalize(self, shutdown, self.processes) + + def close(self) -> None: + self._finalizer() + + class CoreEngineProcManager: """ Utility class to handle creation, readiness, and shutdown @@ -109,7 +190,7 @@ def __init__( local_start_index: int, vllm_config: VllmConfig, on_head_node: bool, - input_address: str, + handshake_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -117,12 +198,12 @@ def __init__( common_kwargs = { "vllm_config": vllm_config, "on_head_node": on_head_node, - "input_address": input_address, + "handshake_address": handshake_address, "executor_class": executor_class, "log_stats": log_stats, } - self.processes: list[Process] = [] + self.processes: list[BaseProcess] = [] for index in range(local_engine_count): local_index = local_start_index + index global_index = start_index + index @@ -135,8 +216,7 @@ def __init__( "local_dp_rank": local_index, })) - self._finalizer = weakref.finalize(self, shutdown, self.processes, - input_address) + self._finalizer = weakref.finalize(self, shutdown, self.processes) try: for proc in self.processes: proc.start() @@ -164,9 +244,357 @@ def finished_procs(self) -> dict[str, int]: } +class CoreEngineActorManager: + """ + Utility class to handle creation, readiness, and shutdown + of core engine Ray actors used by the AsyncLLM and LLMEngine. + + Different from CoreEngineProcManager, this class manages + core engines for both local and remote nodes. + """ + + def __init__( + self, + local_engine_count: int, + start_index: int, + local_start_index: int, + vllm_config: VllmConfig, + addresses, + executor_class: type[Executor], + log_stats: bool, + ): + import copy + + import ray + from ray._private.state import available_resources_per_node + from ray.util.scheduling_strategies import ( + PlacementGroupSchedulingStrategy) + from ray.util.state import list_nodes + + from vllm.v1.engine.core import DPEngineCoreActor + + self.local_engine_actors: list[ray.ActorHandle] = [] + self.remote_engine_actors: list[ray.ActorHandle] = [] + + dp_size = vllm_config.parallel_config.data_parallel_size + remote_engine_count = dp_size - local_engine_count + + if ray.is_initialized(): + logger.info( + "Ray is already initialized. 
Skipping Ray initialization.") + else: + ray.init() + + nodes = list_nodes() + available_resources_by_id = available_resources_per_node() + available_resources_by_ip = {} + num_workers = vllm_config.parallel_config.world_size + + dp_size_available = 0 + for node in nodes: + node_ip = node.node_ip + node_id = node.node_id + node_resources = available_resources_by_id[node_id] + available_resources_by_ip[node_ip] = node_resources + # For now, each DP rank can only be assigned to one node + # TODO(rui): support allocating a single DP rank to multiple nodes + dp_size_available += node_resources["GPU"] // num_workers + + assert dp_size_available >= dp_size, ( + "Not enough resources to allocate DP ranks") + + head_node_ip = \ + vllm_config.parallel_config.data_parallel_master_ip + + refs = [] + for index in range(local_engine_count): + local_index = local_start_index + index + global_index = start_index + index + dp_vllm_config = copy.deepcopy(vllm_config) + bundles = [{ + "GPU": 1.0, + "node:" + head_node_ip: 0.001 + }] * num_workers + [{ + "CPU": 1.0 + }] + pg = ray.util.placement_group( + name=f"dp_rank_{global_index}", + strategy="STRICT_PACK", + bundles=bundles, + ) + dp_vllm_config.parallel_config.placement_group = pg + actor = ray.remote(DPEngineCoreActor).options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=num_workers, + )).remote(vllm_config=dp_vllm_config, + executor_class=executor_class, + log_stats=log_stats, + addresses=addresses, + on_head_node=True, + engine_index=global_index, + dp_rank=global_index, + local_dp_rank=local_index) + self.local_engine_actors.append(actor) + refs.append(actor.wait_for_init.remote()) + + for index in range(remote_engine_count): + local_index = index + global_index = local_engine_count + index + bundles = [{"GPU": 1.0}] * num_workers + [{"CPU": 1.0}] + pg = ray.util.placement_group( + name=f"dp_rank_{global_index}", + strategy="STRICT_PACK", + bundles=bundles, + ) + dp_vllm_config = copy.deepcopy(vllm_config) + dp_vllm_config.parallel_config.placement_group = pg + actor = ray.remote(DPEngineCoreActor).options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=num_workers, + )).remote(vllm_config=dp_vllm_config, + executor_class=executor_class, + log_stats=log_stats, + addresses=addresses, + on_head_node=False, + engine_index=global_index, + dp_rank=global_index, + local_dp_rank=local_index) + self.remote_engine_actors.append(actor) + refs.append(actor.wait_for_init.remote()) + + ray.get(refs) + for actor in self.local_engine_actors + self.remote_engine_actors: + actor.run.remote() + + def close(self): + import ray + for actor in self.local_engine_actors + self.remote_engine_actors: + ray.kill(actor) + + +class CoreEngineState(Enum): + NEW = auto() + CONNECTED = auto() + READY = auto() + + +class CoreEngine: + """One per data parallel rank.""" + + def __init__(self, index: int = 0, local: bool = True): + self.local = local + self.index = index + self.identity = index.to_bytes(2, "little") + + self.state = CoreEngineState.NEW + + +def wait_for_engine_startup( + handshake_socket: zmq.Socket, + addresses: dict[str, Any], + core_engines: list[CoreEngine], + parallel_config: ParallelConfig, + cache_config: CacheConfig, + proc_manager: Optional[CoreEngineProcManager], + coord_process: Optional[Process], +): + + # Wait for engine core process(es) to send ready messages. 
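Looking back at CoreEngineActorManager above, the Ray scheduling pattern is: one STRICT_PACK placement group per DP rank, one GPU bundle per worker plus a trailing CPU bundle that hosts the engine actor, head-node ranks additionally pinned via Ray's per-node "node:<ip>" resource, and actor startup split into a blocking wait_for_init() call followed by run(). A small sketch of that pattern, assuming a Ray node with at least one free GPU; the actor class and names are stand-ins, not vLLM code.

# Illustrative Ray placement-group sketch; requires a node with >= 1 free GPU.
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

@ray.remote
class EngineStub:
    def __init__(self, rank: int):
        self.rank = rank
    def wait_for_init(self):
        return None  # ray.get() on this proves __init__ completed
    def run(self):
        return f"rank {self.rank}: busy loop would start here"

ray.init()
num_workers = 1  # world size per DP rank, assumed 1 for the demo
head_node_ip = ray.util.get_node_ip_address()

# GPU bundles for the workers, pinned to the head node via its node resource,
# plus a CPU bundle (the last one) for the engine actor process itself.
bundles = ([{"GPU": 1.0, "node:" + head_node_ip: 0.001}] * num_workers
           + [{"CPU": 1.0}])
pg = placement_group(bundles, strategy="STRICT_PACK", name="dp_rank_0")
ray.get(pg.ready())

actor = EngineStub.options(
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg,
        placement_group_bundle_index=num_workers,  # index of the CPU bundle
    )).remote(rank=0)

ray.get(actor.wait_for_init.remote())  # block until the actor is constructed
print(ray.get(actor.run.remote()))
ray.shutdown()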
+ local_count = parallel_config.data_parallel_size_local + remote_count = len(core_engines) - local_count + # [local, remote] counts + conn_pending, start_pending = [local_count, remote_count], [0, 0] + poller = zmq.Poller() + poller.register(handshake_socket, zmq.POLLIN) + + if proc_manager is not None: + for sentinel in proc_manager.sentinels(): + poller.register(sentinel, zmq.POLLIN) + if coord_process is not None: + poller.register(coord_process.sentinel, zmq.POLLIN) + while any(conn_pending) or any(start_pending): + events = poller.poll(STARTUP_POLL_PERIOD_MS) + if not events: + if any(conn_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to connect.", *conn_pending) + if any(start_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to start.", *start_pending) + continue + if len(events) > 1 or events[0][0] != handshake_socket: + # One of the local core processes exited. + finished = proc_manager.finished_procs() if proc_manager else {} + if coord_process is not None and coord_process.exitcode is not None: + finished[coord_process.name] = coord_process.exitcode + raise RuntimeError("Engine core initialization failed. " + "See root cause above. " + f"Failed core proc(s): {finished}") + + # Receive HELLO and READY messages from the input socket. + eng_identity, ready_msg_bytes = handshake_socket.recv_multipart() + eng_index = int.from_bytes(eng_identity, "little") + engine = next((e for e in core_engines if e.identity == eng_identity), + None) + if engine is None: + raise RuntimeError(f"Message from engine with unexpected data " + f"parallel rank: {eng_index}") + msg = msgspec.msgpack.decode(ready_msg_bytes) + status, local = msg["status"], msg["local"] + if local != engine.local: + raise RuntimeError(f"{status} message from " + f"{'local' if local else 'remote'} " + f"engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}") + + if status == "HELLO" and engine.state == CoreEngineState.NEW: + + # Send init message with DP config info. + init_message = msgspec.msgpack.encode({ + "addresses": addresses, + "parallel_config": { + "data_parallel_master_ip": + parallel_config.data_parallel_master_ip, + "data_parallel_master_port": + parallel_config.data_parallel_master_port, + "data_parallel_size": parallel_config.data_parallel_size, + }, + }) + handshake_socket.send_multipart((eng_identity, init_message), + copy=False) + conn_pending[0 if local else 1] -= 1 + start_pending[0 if local else 1] += 1 + engine.state = CoreEngineState.CONNECTED + elif status == "READY" and (engine.state == CoreEngineState.CONNECTED): + # Setup KV cache config with initialization state from + # engine core process. Sum values from all engines in DP case. + num_gpu_blocks = cache_config.num_gpu_blocks or 0 + num_gpu_blocks += msg['num_gpu_blocks'] + cache_config.num_gpu_blocks = num_gpu_blocks + + start_pending[0 if local else 1] -= 1 + engine.state = CoreEngineState.READY + else: + raise RuntimeError(f"Unexpected {status} message for " + f"{'local' if local else 'remote'} engine " + f"{eng_index} in {engine.state} state.") + + logger.debug("%s from %s core engine process %s.", status, + "local" if local else "remote", eng_index) + + +def wait_for_completion_or_failure( + api_server_manager: APIServerProcessManager, + local_engine_manager: Optional[CoreEngineProcManager] = None, + coordinator: Optional["DPCoordinator"] = None) -> None: + """Wait for all processes to complete or detect if any fail. 
+ + Raises an exception if any process exits with a non-zero status. + """ + + try: + logger.info("Waiting for API servers to complete ...") + # Create a mapping of sentinels to their corresponding processes + # for efficient lookup + sentinel_to_proc: dict[Any, BaseProcess] = { + proc.sentinel: proc + for proc in api_server_manager.processes + } + + if coordinator: + sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc + + if local_engine_manager: + for proc in local_engine_manager.processes: + sentinel_to_proc[proc.sentinel] = proc + + # Check if any process terminates + while sentinel_to_proc: + # Wait for any process to terminate + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc) + + # Process any terminated processes + for sentinel in ready_sentinels: + proc = sentinel_to_proc.pop(sentinel) + + # Check if process exited with error + if proc.exitcode != 0: + raise RuntimeError( + f"Process {proc.name} (PID: {proc.pid}) " + f"died with exit code {proc.exitcode}") + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, shutting down API servers...") + except Exception as e: + logger.exception("Exception occurred while running API servers: %s", + str(e)) + raise + finally: + logger.info("Terminating remaining processes ...") + api_server_manager.close() + if coordinator: + coordinator.close() + if local_engine_manager: + local_engine_manager.close() + + +def wait_for_ray_engine_actors( + api_server_manager: APIServerProcessManager, + engine_actor_manager: CoreEngineActorManager, + coordinator: Optional["DPCoordinator"] = None) -> None: + """Wait for all ray engine actors to complete or detect if any fail. + + Raises an exception if any process exits with a non-zero status. + """ + + try: + logger.info("Waiting for ray engine actors to complete ...") + # Create a mapping of sentinels to their corresponding processes + # for efficient lookup + sentinel_to_proc: dict[Any, Union[SpawnProcess, Process]] = { + proc.sentinel: proc + for proc in api_server_manager.processes + } + + if coordinator: + sentinel_to_proc.update( + {coordinator.proc.sentinel: coordinator.proc}) + + # TODO(rui): check if any ray engine actor terminates + # Check if any process terminates + while sentinel_to_proc: + # Wait for any process to terminate + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc) + + # Process any terminated processes + for sentinel in ready_sentinels: + proc = sentinel_to_proc.pop(sentinel) + + # Check if process exited with error + if proc.exitcode != 0: + raise RuntimeError( + f"Process {proc.name} (PID: {proc.pid}) " + f"died with exit code {proc.exitcode}") + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, shutting down API servers...") + except Exception as e: + logger.exception("Exception occurred while running API servers: %s", + str(e)) + raise + finally: + logger.info("Terminating remaining processes ...") + api_server_manager.close() + if coordinator: + coordinator.close() + engine_actor_manager.close() + + # Note(rob): shutdown function cannot be a bound method, -# else the gc cannot collect the objedecoupct. -def shutdown(procs: list[Process], input_address: str): +# else the gc cannot collect the object. +def shutdown(procs: list[BaseProcess]): # Shutdown the process. for proc in procs: if proc.is_alive(): @@ -185,12 +613,6 @@ def shutdown(procs: list[Process], input_address: str): if proc.is_alive() and (pid := proc.pid) is not None: kill_process_tree(pid) - # Remove zmq ipc socket files. 
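The monitoring in wait_for_completion_or_failure above reduces to a standard multiprocessing pattern: map each child's sentinel handle to its process, block in connection.wait() until any sentinel becomes ready, and treat a non-zero exitcode as fatal for the whole group. A self-contained sketch of just that mechanism, using throwaway demo workers rather than API servers:

# Self-contained sketch of sentinel-based child monitoring; demo workers only.
import multiprocessing
import sys
import time
from multiprocessing import connection

def worker(runtime: float, exit_code: int) -> None:
    time.sleep(runtime)
    sys.exit(exit_code)

if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    procs = [
        ctx.Process(target=worker, args=(0.5, 0), name="healthy"),
        ctx.Process(target=worker, args=(0.1, 3), name="failing"),
    ]
    for proc in procs:
        proc.start()

    sentinel_to_proc = {proc.sentinel: proc for proc in procs}
    try:
        while sentinel_to_proc:
            # Blocks until at least one child process exits.
            for sentinel in connection.wait(sentinel_to_proc):
                proc = sentinel_to_proc.pop(sentinel)
                proc.join()
                if proc.exitcode != 0:
                    raise RuntimeError(f"{proc.name} (PID {proc.pid}) died "
                                       f"with exit code {proc.exitcode}")
    finally:
        # Mirror the cleanup step: take down whatever is still running.
        for proc in sentinel_to_proc.values():
            proc.terminate()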
- if input_address.startswith("ipc://"): - socket_file = input_address[len("ipc://"):] - if os and os.path.exists(socket_file): - os.remove(socket_file) - def bind_kv_cache( kv_caches: dict[str, torch.Tensor],