[Misc] Better RayExecutor and multiprocessing compatibility #14705


Merged
merged 14 commits into vllm-project:main on Mar 21, 2025

Conversation

@comaniac (Collaborator) commented Mar 12, 2025

One issue that has been frequently reported recently is hanging when the Ray executor is used. Specifically, it happens when:

  1. Create a Ray placement group.
  2. Create a Ray actor using the first bundle of the placement group.
  3. Launch a vLLM engine in the actor.

Note that steps 1-2 are usually done by other Ray libraries or applications such as Ray Serve and Ray Data.

After diving into the implementation, we concluded that this happens because vLLM by default uses fork to create the engine process:

  1. It is not best practice to fork a child process inside a Ray actor, as it results in undefined behavior. For example, this hanging issue occurs because the forked child process tries to access the Ray GCS even though it is now a different process.
  2. Accordingly, we have to use spawn when creating child processes. However, since a spawned child process starts with a fresh environment, it cannot see the placement group. As a result, we have to pass the placement group object from the main process.

This PR fixes the issue by:

  1. Enforcing spawn when Ray is initialized (so that we are likely in an actor).
  2. Getting the current placement group before creating the engine (in the main process), and passing the placement group object to the spawned processes (see the sketch below).
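Conceptually, the parent-side logic looks like the sketch below. This is an illustration rather than the actual vLLM code: the function name is made up, while ray.is_initialized(), ray.util.get_current_placement_group(), and the VLLM_WORKER_MULTIPROC_METHOD environment variable are the real knobs discussed in this PR.

import os

import ray
from ray.util import get_current_placement_group

def prepare_for_engine_spawn():
    """Hypothetical helper: pick the start method and capture the placement
    group in the parent before the engine process is created."""
    pg = None
    if ray.is_initialized():
        # Forking inside a Ray worker/actor leads to undefined behavior,
        # so force spawn for the engine process.
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
        # A spawned child starts with a fresh environment, so the placement
        # group must be captured here and passed to the child explicitly.
        pg = get_current_placement_group()
    return pg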

Example code:

from typing import Any, Dict
import logging

from fastapi import FastAPI
from starlette.requests import Request

import ray
from ray import serve

logger = logging.getLogger("ray.serve")

app = FastAPI()

ray.init(
    runtime_env=dict(
        env_vars=dict(
            VLLM_USE_V1="1"
        )
    )
)

@serve.deployment(
    autoscaling_config={
        "initial_replicas": 1,
        "min_replicas": 1,
        "max_replicas": 1,
        "target_ongoing_requests": 5,
    },
    max_ongoing_requests=10,
)
@serve.ingress(app)
class VLLMDeployment:
    def __init__(
        self,
        engine_args: Dict[str, Any],
    ):
        import vllm
        assert vllm.envs.VLLM_USE_V1

        engine_args = vllm.AsyncEngineArgs(
            **engine_args,
            distributed_executor_backend="ray",
            disable_log_requests=True,
        )
        self.engine = vllm.AsyncLLMEngine.from_engine_args(
            engine_args=engine_args,
        )
        logger.info("Engine initialized")

    @app.post("/generate")
    async def generate(
        self, raw_request: Request
    ):
        import vllm
        import uuid

        request = await raw_request.json()
        stream = self.engine.generate(
            request_id=str(uuid.uuid4()),
            prompt=request["prompt"],
            sampling_params=vllm.SamplingParams(**request["params"]),
        )
        async for request_output in stream:
            if request_output.finished:
                return request_output


def build_app(cli_args: Dict[str, str]) -> serve.Application:
    pg_resources = [
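        # The first bundle hosts the Serve replica actor; the GPU bundle is used by the vLLM engine worker.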
        {"CPU": 1}, {"CPU": 1, "GPU": 1},
    ]
    return VLLMDeployment.options(
        placement_group_bundles=pg_resources,
        placement_group_strategy="STRICT_PACK",
    ).bind(cli_args)
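For reference, one way to launch this deployment from Python is shown below; the engine-args dict (including the model id) is a placeholder and not part of this PR.

from ray import serve

# build_app is the application builder defined above; the dict is forwarded
# to vllm.AsyncEngineArgs inside VLLMDeployment.__init__.
serve.run(build_app({"model": "your-model-id"}))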

cc @ruisearch42 @youkaichao


@mergify mergify bot added the v1 label Mar 12, 2025
@kouroshHakha (Contributor) commented:

Two features that are very important here: 1) being able to pass placement groups directly to engine_args instead of hacking around vllm_config, and 2) forcing spawn when Ray is used.

Question for vLLM committers: is there any good reason to fork the process rather than always using spawn? cc @WoosukKwon

@comaniac (Collaborator, Author) commented:

> Two features that are very important here: 1) being able to pass placement groups directly to engine_args instead of hacking around vllm_config, and 2) forcing spawn when Ray is used.
>
> Question for vLLM committers: is there any good reason to fork the process rather than always using spawn? cc @WoosukKwon

@russellb mentioned to me that spawn will break some existing code using vLLM as a library and he will share more details later.

@ruisearch42 (Collaborator) left a comment:

Overall LGTM. We will need to better understand why vLLM currently prefers fork over spawn.

@youkaichao (Member) commented:

> Two features that are very important here: 1) being able to pass placement groups directly to engine_args instead of hacking around vllm_config, and 2) forcing spawn when Ray is used.
>
> Question for vLLM committers: is there any good reason to fork the process rather than always using spawn? cc @WoosukKwon
>
> @russellb mentioned to me that spawn will break some existing code using vLLM as a library and he will share more details later.

spawn does not work if users do not have an if __name__ == "__main__" guard set up correctly, e.g. if people just run this in a Python shell:

import vllm
vllm.LLM()

spawn does not work in Jupyter notebooks either; see https://stackoverflow.com/questions/48846085/python-multiprocessing-within-jupyter-notebook.

When we know it's safe to use spawn, we can use it instead of fork. For example, when we are creating an API server there is an explicit if __name__ == "__main__" guard, so we can use spawn by default:

if __name__ == "__main__":
    llm = vllm.LLM()  # spawn is safe here because there is an explicit entry point

W.r.t. Ray: since fork does not work for Ray, we can use spawn if we find that the current process is a Ray actor. That's totally fine.

@youkaichao (Member) commented:

Thanks for pointing it out; I think this is exactly the same issue I'm trying to solve in #14410.

To summarize, we can:

  • change VLLM_WORKER_MULTIPROC_METHOD to spawn if we are in a Ray actor (help needed from the Ray team: how to accurately tell whether the current process is a Ray actor)
  • pass the current placement group to the spawned process (help needed from the Ray team: how to reliably get the current placement group, and whether a placement group is safe to serialize); a rough sketch of both checks follows below
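One possible way to implement both checks with public Ray APIs is sketched here; this assumes ray.get_runtime_context().get_actor_id() returns None outside of an actor, and it is not necessarily what the PR ended up doing.

import ray
from ray.util import get_current_placement_group

def in_ray_actor() -> bool:
    """Best-effort check for whether this process is a Ray actor."""
    if not ray.is_initialized():
        return False
    # get_actor_id() returns the actor ID inside an actor and None otherwise.
    return ray.get_runtime_context().get_actor_id() is not None

def current_pg_or_none():
    """Placement group the current actor/task is scheduled in, if any."""
    return get_current_placement_group() if ray.is_initialized() else None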

@ruisearch42 (Collaborator) commented:

@comaniac, I think what @youkaichao said makes sense. I created a PR here: #14768

@kouroshHakha (Contributor) commented:

@youkaichao fork() has apparently had long-standing issues in various contexts, and in Python 3.14 they are switching the default to spawn(). python/cpython#84559

@ruisearch42 (Collaborator) commented:

@kouroshHakha looks like vLLM documents the tradeoffs: https://docs.vllm.ai/en/latest/design/multiprocessing.html

@mergify mergify bot added the documentation label (Improvements or additions to documentation) Mar 19, 2025
@comaniac (Collaborator, Author) commented Mar 19, 2025

Per offline discussion with @youkaichao, here is the latest behavior in a spawned process:

  1. If parallel_config.placement_group is given, use it.
  2. Otherwise, if RAY_PLACEMENT_GROUP is given, use it.
  3. Otherwise, use ray.util.get_current_placement_group() if it is available.
  4. Otherwise, create a new placement group.

Note that since environment variables can only hold strings, we added utilities to serialize/deserialize placement groups to and from strings; these utilities should ideally be implemented in Ray. A rough sketch of the overall resolution logic is below.
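For illustration only, the resolution order plus the string round-tripping could look roughly like this. The helper names, the cloudpickle/base64 encoding, and the fallback bundle spec are assumptions made here for clarity; RAY_PLACEMENT_GROUP and parallel_config.placement_group come from the discussion above, and whether a PlacementGroup handle survives cloudpickle round-tripping is itself an assumption rather than something verified in this thread.

import base64
import os

import cloudpickle
from ray.util import get_current_placement_group, placement_group

def pg_to_str(pg) -> str:
    # Assumption: the PlacementGroup handle is cloudpickle-serializable.
    return base64.b64encode(cloudpickle.dumps(pg)).decode("ascii")

def pg_from_str(value: str):
    return cloudpickle.loads(base64.b64decode(value))

def resolve_placement_group(parallel_config):
    # 1. An explicitly provided placement group wins.
    if parallel_config.placement_group is not None:
        return parallel_config.placement_group
    # 2. Otherwise, use a placement group handed down via the environment.
    encoded = os.environ.get("RAY_PLACEMENT_GROUP")
    if encoded:
        return pg_from_str(encoded)
    # 3. Otherwise, use the placement group the current actor/task runs in.
    pg = get_current_placement_group()
    if pg is not None:
        return pg
    # 4. Otherwise, create a new one (the bundle spec here is illustrative).
    return placement_group([{"CPU": 1, "GPU": 1}])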

cc @ruisearch42 @kouroshHakha @richardliaw

@youkaichao (Member) left a comment:

thanks for the fix!

@comaniac comaniac added the ready label (ONLY add when PR is ready to merge/full CI is needed) Mar 19, 2025
@comaniac comaniac enabled auto-merge (squash) March 19, 2025 18:16
@vllm-bot vllm-bot merged commit 5df2da5 into vllm-project:main Mar 21, 2025
34 of 36 checks passed