Skip to content

[Misc] Better RayExecutor and multiprocessing compatibility #14705

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,6 +1240,14 @@ def create_engine_config(
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
)

try:
import ray

placement_group = ray.util.get_current_placement_group()
except ImportError:
placement_group = None

parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
Expand All @@ -1252,6 +1260,7 @@ def create_engine_config(
self.tokenizer_pool_extra_config,
),
ray_workers_use_nsight=self.ray_workers_use_nsight,
placement_group=placement_group,
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
Expand Down
4 changes: 2 additions & 2 deletions vllm/executor/multiproc_worker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import _check_multiproc_method, get_mp_context, run_method
from vllm.utils import _maybe_force_spawn, get_mp_context, run_method

logger = init_logger(__name__)

Expand Down Expand Up @@ -291,7 +291,7 @@ def set_multiprocessing_worker_envs(parallel_config):
in a multiprocessing environment. This should be called by the parent
process before worker processes are created"""

_check_multiproc_method()
_maybe_force_spawn()

# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
Expand Down
21 changes: 13 additions & 8 deletions vllm/executor/ray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,9 @@ def initialize_ray_cluster(
assert_ray_available()
from vllm.platforms import current_platform

# Connect to a ray cluster.
if current_platform.is_rocm() or current_platform.is_xpu():
if ray.is_initialized():
logger.info("Ray is already initialized. Skipping Ray initialization.")
elif current_platform.is_rocm() or current_platform.is_xpu():
# Try to connect existing ray instance and create a new one if not found
try:
ray.init("auto", ignore_reinit_error=True)
Expand All @@ -299,19 +300,21 @@ def initialize_ray_cluster(
else:
ray.init(address=ray_address, ignore_reinit_error=True)

if parallel_config.placement_group:
# Placement group is already set.
return

device_str = current_platform.ray_device_key
if not device_str:
raise ValueError(
f"current platform {current_platform.device_name} does not "
"support ray.")

# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
# Create or get the placement group for worker processes
if parallel_config.placement_group:
current_placement_group = parallel_config.placement_group
else:
current_placement_group = ray.util.get_current_placement_group()

if current_placement_group:
logger.info("Using the existing placement group")

# We are in a placement group
bundles = current_placement_group.bundle_specs
# Verify that we can use the placement group.
Expand All @@ -331,6 +334,8 @@ def initialize_ray_cluster(
f"Required number of devices: {parallel_config.world_size}. "
f"Total number of devices: {device_bundles}.")
else:
logger.info("No current placement group found. "
"Creating a new placement group.")
num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
# Log a warning message and delay resource allocation failure response.
# Avoid immediate rejection to allow user-initiated placement group
Expand Down
47 changes: 37 additions & 10 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2142,20 +2142,47 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
ctx.destroy(linger=0)


def _check_multiproc_method():
if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/getting_started/"
"troubleshooting.html#python-multiprocessing "
"for more information.")
def ray_is_initialized():
    """Return True if the Ray runtime has been initialized.

    A missing ``ray`` installation is treated as "not initialized"
    instead of raising, so callers can probe safely.
    """
    try:
        import ray
    except ImportError:
        # Ray is an optional dependency; absence simply means "no".
        return False
    return ray.is_initialized()


def _maybe_force_spawn():
    """Check if we need to force the use of the `spawn` multiprocessing start
    method.
    """
    # Nothing to do when the user already opted into spawn explicitly.
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return

    # Determine why spawn is required, if at all. `fork` is unsafe once
    # CUDA or Ray state exists in the parent process.
    if cuda_is_initialized():
        reason = "CUDA is initialized"
    elif ray_is_initialized():
        reason = "Ray is initialized and Ray process can only be spawned"
    else:
        return

    logger.warning(
        "We must use the `spawn` multiprocessing start method. "
        "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
        "See https://docs.vllm.ai/en/latest/getting_started/"
        "troubleshooting.html#python-multiprocessing "
        "for more information. Reason: %s", reason)
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


def get_mp_context():
    """Get a multiprocessing context with a particular method (spawn or fork).
    By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
    determine the multiprocessing method (default is fork). However, under
    certain conditions, we may enforce spawn and override the value of
    VLLM_WORKER_MULTIPROC_METHOD.
    """
    # Fix: the stale `_check_multiproc_method()` call (the pre-rename helper,
    # removed in this change) is dropped — only `_maybe_force_spawn` exists
    # now — and the docstring is moved above the first statement so it is an
    # actual docstring rather than a dead string expression.
    _maybe_force_spawn()
    mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
    return multiprocessing.get_context(mp_method)

Expand Down