Skip to content

Commit 4b4ee56

Browse files
WoosukKwonLeiWang1999
authored andcommitted
[Misc][TPU] Support TPU in initialize_ray_cluster (vllm-project#6812)
Signed-off-by: LeiWang1999 <[email protected]>
1 parent 9119528 commit 4b4ee56

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

vllm/executor/ray_utils.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from vllm.config import ParallelConfig
44
from vllm.logger import init_logger
55
from vllm.sequence import ExecuteModelRequest
6-
from vllm.utils import get_ip, is_hip, is_xpu
6+
from vllm.utils import get_ip, is_hip, is_tpu, is_xpu
77
from vllm.worker.worker_base import WorkerWrapperBase
88

99
logger = init_logger(__name__)
@@ -93,32 +93,38 @@ def initialize_ray_cluster(
9393
# Placement group is already set.
9494
return
9595

96+
device_str = "GPU" if not is_tpu() else "TPU"
9697
# Create placement group for worker processes
9798
current_placement_group = ray.util.get_current_placement_group()
9899
if current_placement_group:
99100
# We are in a placement group
100101
bundles = current_placement_group.bundle_specs
101102
# Verify that we can use the placement group.
102-
gpu_bundles = 0
103+
device_bundles = 0
103104
for bundle in bundles:
104-
bundle_gpus = bundle.get("GPU", 0)
105-
if bundle_gpus > 1:
105+
bundle_devices = bundle.get(device_str, 0)
106+
if bundle_devices > 1:
106107
raise ValueError(
107-
"Placement group bundle cannot have more than 1 GPU.")
108-
if bundle_gpus:
109-
gpu_bundles += 1
110-
if parallel_config.world_size > gpu_bundles:
108+
"Placement group bundle cannot have more than 1 "
109+
f"{device_str}.")
110+
if bundle_devices:
111+
device_bundles += 1
112+
if parallel_config.world_size > device_bundles:
111113
raise ValueError(
112-
"The number of required GPUs exceeds the total number of "
113-
"available GPUs in the placement group.")
114+
f"The number of required {device_str}s exceeds the total "
115+
f"number of available {device_str}s in the placement group."
116+
f"Required number of devices: {parallel_config.world_size}. "
117+
f"Total number of devices: {device_bundles}.")
114118
else:
115-
num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
116-
if parallel_config.world_size > num_gpus_in_cluster:
119+
num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
120+
if parallel_config.world_size > num_devices_in_cluster:
117121
raise ValueError(
118-
"The number of required GPUs exceeds the total number of "
119-
"available GPUs in the cluster.")
122+
f"The number of required {device_str}s exceeds the total "
123+
f"number of available {device_str}s in the placement group.")
120124
# Create a new placement group
121-
placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
125+
placement_group_specs = ([{
126+
device_str: 1
127+
}] * parallel_config.world_size)
122128
current_placement_group = ray.util.placement_group(
123129
placement_group_specs)
124130
# Wait until PG is ready - this will block until all

0 commit comments

Comments
 (0)