
Commit 6907571

Commit message: fix
Signed-off-by: Cody Yu <[email protected]>
Parent: 9690e43

File tree: 3 files changed, +55 -3 lines

examples/offline_inference/rlhf.py

Lines changed: 3 additions & 2 deletions
@@ -44,8 +44,9 @@ def __init__(self, *args, **kwargs):
 train_model.to("cuda:0")
 """
 Start the inference process, here we use vLLM to hold a model on GPU 1 and
-GPU 2. For the details on how to use ray, please refer to the ray
-documentation https://docs.ray.io/en/latest/ .
+GPU 2 by creating a Ray placement group. The placement group will be passed
+to the worker processes spawned by vLLM. For the details on how to use Ray,
+please refer to the Ray documentation https://docs.ray.io/en/latest/ .
 """
 os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
 ray.init()
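
As context for the docstring change above, here is a rough sketch, not taken from this commit, of how a Ray placement group spanning GPUs 1 and 2 might be created so that vLLM's worker processes can later be scheduled into it. The bundle shapes and variable names are assumptions for illustration only.

import os

import ray
from ray.util.placement_group import placement_group

# Keep GPUs 1 and 2 visible for the inference side, as in the example.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()

# Hypothetical bundle layout: one bundle per inference GPU. vLLM's Ray
# workers would then be scheduled into these bundles.
pg = placement_group([{"GPU": 1, "CPU": 1}] * 2)
ray.get(pg.ready())  # block until the resources are actually reserved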

vllm/envs.py

Lines changed: 7 additions & 0 deletions
@@ -86,6 +86,7 @@
     VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
+    VLLM_RAY_PLACEMENT_GROUP: Optional[str] = None
     VLLM_CUDART_SO_PATH: Optional[str] = None
     VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
     VLLM_DP_RANK: int = 0
@@ -577,6 +578,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_RAY_BUNDLE_INDICES":
     lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""),
 
+    # Ray placement group (serialized string), if it is set and
+    # ray.util.get_current_placement_group() is None, it will be used as the
+    # placement group in vLLM Ray executor.
+    "VLLM_RAY_PLACEMENT_GROUP":
+    lambda: os.getenv("VLLM_RAY_PLACEMENT_GROUP", None),
+
     # In some system, find_loaded_library() may not work. So we allow users to
     # specify the path through environment variable VLLM_CUDART_SO_PATH.
     "VLLM_CUDART_SO_PATH":

vllm/executor/ray_utils.py

Lines changed: 45 additions & 1 deletion
@@ -1,12 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import json
 import os
 import time
 from collections import defaultdict
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import msgspec
 
+import vllm.envs as envs
 import vllm.platforms
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
@@ -162,6 +164,41 @@ def assert_ray_available():
            "`pip install ray`.") from ray_import_err
 
 
+def serialize_placement_group_to_str(placement_group: "PlacementGroup") -> str:
+    """Serialize a placement group to a string.
+    FIXME: This should be implemented in Ray.
+
+    Args:
+        placement_group: The placement group to serialize.
+
+    Returns:
+        A string representation of the placement group.
+    """
+    placement_group_data = {
+        "id": placement_group.id.hex(),
+        "bundle_cache": placement_group.bundle_cache,
+    }
+    return json.dumps(placement_group_data)
+
+
+def deserialize_placement_group_from_str(
+        placement_group_str: str) -> "PlacementGroup":
+    """Deserialize a placement group from a string.
+    FIXME: This should be implemented in Ray.
+
+    Args:
+        placement_group_str: The string representation of the placement group.
+
+    Returns:
+        A placement group.
+    """
+    placement_group_data = json.loads(placement_group_str)
+    return PlacementGroup(
+        id=ray._raylet.PlacementGroupID.from_hex(placement_group_data["id"]),
+        bundle_cache=placement_group_data["bundle_cache"],
+    )
+
+
 def _verify_bundles(placement_group: "PlacementGroup",
                     parallel_config: ParallelConfig, device_str: str):
     """Verify a given placement group has bundles located in the right place.
@@ -308,12 +345,19 @@ def initialize_ray_cluster(
 
     # Create or get the placement group for worker processes
     if parallel_config.placement_group:
+        logger.info(
+            "Using the existing Ray placement group from parallel config")
         current_placement_group = parallel_config.placement_group
+    elif envs.VLLM_RAY_PLACEMENT_GROUP:
+        logger.info("Using the existing Ray placement group from "
+                    "VLLM_RAY_PLACEMENT_GROUP")
+        current_placement_group = deserialize_placement_group_from_str(
+            envs.VLLM_RAY_PLACEMENT_GROUP)
     else:
+        logger.info("Trying to get the existing Ray placement group")
         current_placement_group = ray.util.get_current_placement_group()
 
     if current_placement_group:
-        logger.info("Using the existing placement group")
 
         # We are in a placement group
         bundles = current_placement_group.bundle_specs
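
With this change, initialize_ray_cluster resolves the placement group in priority order: parallel_config.placement_group first, then a group deserialized from VLLM_RAY_PLACEMENT_GROUP, and finally whatever ray.util.get_current_placement_group() returns. The sketch below is a small, illustrative round-trip check of the two new helpers; the bundle shape and variable names are assumptions, not part of the commit.

import ray
from ray.util.placement_group import placement_group

from vllm.executor.ray_utils import (deserialize_placement_group_from_str,
                                     serialize_placement_group_to_str)

ray.init()
pg = placement_group([{"GPU": 1, "CPU": 1}] * 2)  # illustrative bundles
ray.get(pg.ready())

# The serialized string carries only the group id and the bundle cache,
# which is exactly what the deserializer reconstructs.
restored = deserialize_placement_group_from_str(
    serialize_placement_group_to_str(pg))
assert restored.id.hex() == pg.id.hex()  # same underlying placement group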
