
Commit b02f62d

comaniac authored and Mu Huai committed
[Misc] Better RayExecutor and multiprocessing compatibility (vllm-project#14705)
Signed-off-by: Cody Yu <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent f47b2d3 commit b02f62d

File tree

4 files changed: +67 / -21 lines changed


vllm/engine/arg_utils.py

Lines changed: 14 additions & 1 deletion
@@ -26,7 +26,7 @@
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, StoreBoolean
+from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor

 if TYPE_CHECKING:
     from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

@@ -1245,6 +1245,18 @@ def create_engine_config(
             cpu_offload_gb=self.cpu_offload_gb,
             calculate_kv_scales=self.calculate_kv_scales,
         )
+
+        # Get the current placement group if Ray is initialized and
+        # we are in a Ray actor. If so, then the placement group will be
+        # passed to spawned processes.
+        placement_group = None
+        if is_in_ray_actor():
+            import ray
+
+            # This call initializes Ray automatically if it is not initialized,
+            # but we should not do this here.
+            placement_group = ray.util.get_current_placement_group()
+
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,

@@ -1257,6 +1269,7 @@ def create_engine_config(
                 self.tokenizer_pool_extra_config,
             ),
             ray_workers_use_nsight=self.ray_workers_use_nsight,
+            placement_group=placement_group,
             distributed_executor_backend=self.distributed_executor_backend,
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
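The capture above matters when the engine is built from inside a Ray actor: the actor's placement group is recorded once in create_engine_config() and then travels with ParallelConfig into any spawned worker processes. Below is a minimal, hypothetical sketch of that capture outside vLLM; the EngineLauncher actor and its wiring are illustrative, and only the ray.util.get_current_placement_group() call mirrors the diff.

import ray
from ray.util import placement_group, placement_group_table
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


@ray.remote(num_cpus=1)
class EngineLauncher:
    """Hypothetical actor that would build an LLM engine internally."""

    def capture_pg(self):
        # Same call create_engine_config() now makes: inside an actor that was
        # scheduled into a placement group, this returns that group, and vLLM
        # forwards it via ParallelConfig(placement_group=...).
        pg = ray.util.get_current_placement_group()
        return placement_group_table(pg) if pg is not None else None


if __name__ == "__main__":
    ray.init()
    pg = placement_group([{"CPU": 1}])
    ray.get(pg.ready())
    launcher = EngineLauncher.options(
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg)).remote()
    print(ray.get(launcher.capture_pg.remote()))  # shows the captured group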

vllm/executor/multiproc_worker_utils.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import _check_multiproc_method, get_mp_context, run_method
+from vllm.utils import _maybe_force_spawn, get_mp_context, run_method

 logger = init_logger(__name__)

@@ -291,7 +291,7 @@ def set_multiprocessing_worker_envs(parallel_config):
     in a multiprocessing environment. This should be called by the parent
     process before worker processes are created"""

-    _check_multiproc_method()
+    _maybe_force_spawn()

     # Configure thread parallelism if OMP_NUM_THREADS isn't set
     #
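The renamed helper runs in the parent before any workers are created. The reason it may force spawn: once CUDA is initialized, or when the parent itself runs inside a Ray actor, forking is unsafe. The snippet below is a standalone illustration of that guard rather than code from this commit; cuda_up and in_ray_actor are stand-ins for vLLM's cuda_is_initialized() and is_in_ray_actor().

import multiprocessing
import os


def pick_start_method(cuda_up: bool, in_ray_actor: bool) -> str:
    # Mirrors _maybe_force_spawn(): honor an explicit "spawn", otherwise
    # force it whenever forking the parent process would be unsafe.
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return "spawn"
    if cuda_up or in_ray_actor:
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
        return "spawn"
    return "fork"


if __name__ == "__main__":
    ctx = multiprocessing.get_context(pick_start_method(cuda_up=True,
                                                        in_ray_actor=False))
    proc = ctx.Process(target=print, args=("worker started",))
    proc.start()
    proc.join()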

vllm/executor/ray_utils.py

Lines changed: 13 additions & 8 deletions
@@ -284,8 +284,9 @@ def initialize_ray_cluster(
     assert_ray_available()
     from vllm.platforms import current_platform

-    # Connect to a ray cluster.
-    if current_platform.is_rocm() or current_platform.is_xpu():
+    if ray.is_initialized():
+        logger.info("Ray is already initialized. Skipping Ray initialization.")
+    elif current_platform.is_rocm() or current_platform.is_xpu():
         # Try to connect existing ray instance and create a new one if not found
         try:
             ray.init("auto", ignore_reinit_error=True)

@@ -299,19 +300,21 @@ def initialize_ray_cluster(
     else:
         ray.init(address=ray_address, ignore_reinit_error=True)

-    if parallel_config.placement_group:
-        # Placement group is already set.
-        return
-
     device_str = current_platform.ray_device_key
     if not device_str:
         raise ValueError(
             f"current platform {current_platform.device_name} does not "
             "support ray.")

-    # Create placement group for worker processes
-    current_placement_group = ray.util.get_current_placement_group()
+    # Create or get the placement group for worker processes
+    if parallel_config.placement_group:
+        current_placement_group = parallel_config.placement_group
+    else:
+        current_placement_group = ray.util.get_current_placement_group()
+
     if current_placement_group:
+        logger.info("Using the existing placement group")
+
         # We are in a placement group
         bundles = current_placement_group.bundle_specs
         # Verify that we can use the placement group.

@@ -331,6 +334,8 @@ def initialize_ray_cluster(
                 f"Required number of devices: {parallel_config.world_size}. "
                 f"Total number of devices: {device_bundles}.")
     else:
+        logger.info("No current placement group found. "
+                    "Creating a new placement group.")
         num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
         # Log a warning message and delay resource allocation failure response.
         # Avoid immediate rejection to allow user-initiated placement group
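The net effect in initialize_ray_cluster() is a simple precedence order: use the placement group passed through ParallelConfig if there is one, otherwise fall back to the ambient placement group of the current task or actor, and only then create a new one. A condensed, illustrative sketch of that order follows; pick_placement_group is not a real vLLM helper.

from typing import Optional

import ray
from ray.util.placement_group import PlacementGroup


def pick_placement_group(
        configured_pg: Optional[PlacementGroup]) -> Optional[PlacementGroup]:
    if configured_pg is not None:
        # Explicitly passed via ParallelConfig (e.g. captured in a Ray actor).
        return configured_pg
    # Ambient placement group of the current task/actor, if any.
    ambient_pg = ray.util.get_current_placement_group()
    if ambient_pg is not None:
        return ambient_pg
    # Otherwise initialize_ray_cluster() goes on to create a fresh group
    # sized to parallel_config.world_size.
    return None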

vllm/utils.py

Lines changed: 38 additions & 10 deletions
@@ -2147,20 +2147,48 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
         ctx.destroy(linger=0)


-def _check_multiproc_method():
-    if (cuda_is_initialized()
-            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
-        logger.warning("CUDA was previously initialized. We must use "
-                       "the `spawn` multiprocessing start method. Setting "
-                       "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
-                       "See https://docs.vllm.ai/en/latest/getting_started/"
-                       "troubleshooting.html#python-multiprocessing "
-                       "for more information.")
+def is_in_ray_actor():
+    """Check if we are in a Ray actor."""
+
+    try:
+        import ray
+        return (ray.is_initialized()
+                and ray.get_runtime_context().get_actor_id() is not None)
+    except ImportError:
+        return False
+
+
+def _maybe_force_spawn():
+    """Check if we need to force the use of the `spawn` multiprocessing start
+    method.
+    """
+    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
+        return
+
+    reason = None
+    if cuda_is_initialized():
+        reason = "CUDA is initialized"
+    elif is_in_ray_actor():
+        reason = "In a Ray actor and can only be spawned"
+
+    if reason is not None:
+        logger.warning(
+            "We must use the `spawn` multiprocessing start method. "
+            "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
+            "See https://docs.vllm.ai/en/latest/getting_started/"
+            "troubleshooting.html#python-multiprocessing "
+            "for more information. Reason: %s", reason)
         os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


 def get_mp_context():
-    _check_multiproc_method()
+    """Get a multiprocessing context with a particular method (spawn or fork).
+    By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
+    determine the multiprocessing method (default is fork). However, under
+    certain conditions, we may enforce spawn and override the value of
+    VLLM_WORKER_MULTIPROC_METHOD.
+    """
+    _maybe_force_spawn()
     mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
     return multiprocessing.get_context(mp_method)

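A quick way to exercise the two new helpers (illustrative only, assuming Ray and vLLM are installed): in the driver process is_in_ray_actor() is False, while inside a Ray actor it is True, so get_mp_context() is expected to hand back a spawn context even when VLLM_WORKER_MULTIPROC_METHOD was left at its default.

import ray
from vllm.utils import get_mp_context, is_in_ray_actor


@ray.remote
class Probe:
    def report(self):
        # Expected to be (True, "spawn") with this change in place.
        return is_in_ray_actor(), get_mp_context().get_start_method()


if __name__ == "__main__":
    ray.init()
    print(is_in_ray_actor())  # False: the driver is not a Ray actor
    probe = Probe.remote()
    print(ray.get(probe.report.remote()))
    ray.shutdown()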
