|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 |
|
| 3 | +import json |
3 | 4 | import os
|
4 | 5 | import time
|
5 | 6 | from collections import defaultdict
|
6 | 7 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
7 | 8 |
|
8 | 9 | import msgspec
|
9 | 10 |
|
| 11 | +import vllm.envs as envs |
10 | 12 | import vllm.platforms
|
11 | 13 | from vllm.config import ParallelConfig
|
12 | 14 | from vllm.executor.msgspec_utils import decode_hook, encode_hook
|
@@ -162,6 +164,41 @@ def assert_ray_available():
|
162 | 164 | "`pip install ray`.") from ray_import_err
|
163 | 165 |
|
164 | 166 |
|
def serialize_placement_group_to_str(placement_group: "PlacementGroup") -> str:
    """Serialize a placement group to a JSON string.

    FIXME: This should be implemented in Ray.

    Only two attributes are captured: the hex form of the placement group
    id and the ``bundle_cache`` attribute exactly as stored on the object.

    Args:
        placement_group: The placement group to serialize.

    Returns:
        A JSON object string with the keys ``"id"`` and ``"bundle_cache"``.
    """
    placement_group_data = {
        "id": placement_group.id.hex(),
        "bundle_cache": placement_group.bundle_cache,
    }
    return json.dumps(placement_group_data)
| 182 | + |
| 183 | + |
def deserialize_placement_group_from_str(
        placement_group_str: str) -> "PlacementGroup":
    """Deserialize a placement group from a JSON string.

    FIXME: This should be implemented in Ray.

    Args:
        placement_group_str: The string representation of the placement
            group: a JSON object with the keys ``"id"`` (hex string) and
            ``"bundle_cache"``.

    Returns:
        A placement group built from the decoded id and bundle cache.

    Raises:
        json.JSONDecodeError: If ``placement_group_str`` is not valid JSON.
        KeyError: If the decoded object lacks ``"id"`` or ``"bundle_cache"``.
    """
    # Parse and pull out both fields before touching Ray, so malformed input
    # fails with a JSON/KeyError instead of a Ray-side error.
    placement_group_data = json.loads(placement_group_str)
    placement_group_id_hex = placement_group_data["id"]
    bundle_cache = placement_group_data["bundle_cache"]

    # NOTE(review): PlacementGroup is only used as a quoted annotation in the
    # visible part of this module (presumably imported under TYPE_CHECKING),
    # so it is imported here to guarantee the name exists at call time.
    import ray
    from ray.util.placement_group import PlacementGroup

    return PlacementGroup(
        id=ray._raylet.PlacementGroupID.from_hex(placement_group_id_hex),
        bundle_cache=bundle_cache,
    )
| 200 | + |
| 201 | + |
165 | 202 | def _verify_bundles(placement_group: "PlacementGroup",
|
166 | 203 | parallel_config: ParallelConfig, device_str: str):
|
167 | 204 | """Verify a given placement group has bundles located in the right place.
|
@@ -308,12 +345,19 @@ def initialize_ray_cluster(
|
308 | 345 |
|
309 | 346 | # Create or get the placement group for worker processes
|
310 | 347 | if parallel_config.placement_group:
|
| 348 | + logger.info( |
| 349 | + "Using the existing Ray placement group from parallel config") |
311 | 350 | current_placement_group = parallel_config.placement_group
|
| 351 | + elif envs.VLLM_RAY_PLACEMENT_GROUP: |
| 352 | + logger.info("Using the existing Ray placement group from " |
| 353 | + "VLLM_RAY_PLACEMENT_GROUP") |
| 354 | + current_placement_group = deserialize_placement_group_from_str( |
| 355 | + envs.VLLM_RAY_PLACEMENT_GROUP) |
312 | 356 | else:
|
| 357 | + logger.info("Trying to get the existing Ray placement group") |
313 | 358 | current_placement_group = ray.util.get_current_placement_group()
|
314 | 359 |
|
315 | 360 | if current_placement_group:
|
316 |
| - logger.info("Using the existing placement group") |
317 | 361 |
|
318 | 362 | # We are in a placement group
|
319 | 363 | bundles = current_placement_group.bundle_specs
|
|
0 commit comments