|
3 | 3 | from vllm.config import ParallelConfig
|
4 | 4 | from vllm.logger import init_logger
|
5 | 5 | from vllm.sequence import ExecuteModelRequest
|
6 |
| -from vllm.utils import get_ip, is_hip, is_xpu |
| 6 | +from vllm.utils import get_ip, is_hip, is_tpu, is_xpu |
7 | 7 | from vllm.worker.worker_base import WorkerWrapperBase
|
8 | 8 |
|
9 | 9 | logger = init_logger(__name__)
|
@@ -93,32 +93,38 @@ def initialize_ray_cluster(
|
93 | 93 | # Placement group is already set.
|
94 | 94 | return
|
95 | 95 |
|
| 96 | + device_str = "GPU" if not is_tpu() else "TPU" |
96 | 97 | # Create placement group for worker processes
|
97 | 98 | current_placement_group = ray.util.get_current_placement_group()
|
98 | 99 | if current_placement_group:
|
99 | 100 | # We are in a placement group
|
100 | 101 | bundles = current_placement_group.bundle_specs
|
101 | 102 | # Verify that we can use the placement group.
|
102 |
| - gpu_bundles = 0 |
| 103 | + device_bundles = 0 |
103 | 104 | for bundle in bundles:
|
104 |
| - bundle_gpus = bundle.get("GPU", 0) |
105 |
| - if bundle_gpus > 1: |
| 105 | + bundle_devices = bundle.get(device_str, 0) |
| 106 | + if bundle_devices > 1: |
106 | 107 | raise ValueError(
|
107 |
| - "Placement group bundle cannot have more than 1 GPU.") |
108 |
| - if bundle_gpus: |
109 |
| - gpu_bundles += 1 |
110 |
| - if parallel_config.world_size > gpu_bundles: |
| 108 | + "Placement group bundle cannot have more than 1 " |
| 109 | + f"{device_str}.") |
| 110 | + if bundle_devices: |
| 111 | + device_bundles += 1 |
| 112 | + if parallel_config.world_size > device_bundles: |
111 | 113 | raise ValueError(
|
112 |
| - "The number of required GPUs exceeds the total number of " |
113 |
| - "available GPUs in the placement group.") |
| 114 | + f"The number of required {device_str}s exceeds the total " |
| 115 | + f"number of available {device_str}s in the placement group. " |
| 116 | + f"Required number of devices: {parallel_config.world_size}. " |
| 117 | + f"Total number of devices: {device_bundles}.") |
114 | 118 | else:
|
115 |
| - num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0) |
116 |
| - if parallel_config.world_size > num_gpus_in_cluster: |
| 119 | + num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) |
| 120 | + if parallel_config.world_size > num_devices_in_cluster: |
117 | 121 | raise ValueError(
|
118 |
| - "The number of required GPUs exceeds the total number of " |
119 |
| - "available GPUs in the cluster.") |
| 122 | + f"The number of required {device_str}s exceeds the total " |
| 123 | + f"number of available {device_str}s in the cluster.") |
120 | 124 | # Create a new placement group
|
121 |
| - placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size) |
| 125 | + placement_group_specs = ([{ |
| 126 | + device_str: 1 |
| 127 | + }] * parallel_config.world_size) |
122 | 128 | current_placement_group = ray.util.placement_group(
|
123 | 129 | placement_group_specs)
|
124 | 130 | # Wait until PG is ready - this will block until all
|
|
0 commit comments