import os
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import RayWorkerWrapper, ray
from vllm.v1.outputs import ModelRunnerOutput

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)


class RayExecutor(Executor):

    def __init__(self, vllm_config: VllmConfig) -> None:
        self.vllm_config = vllm_config
        self.parallel_config = vllm_config.parallel_config
        self.model_config = vllm_config.model_config
        self.forward_dag: Optional[ray.dag.CompiledDAG] = None

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        placement_group = self.parallel_config.placement_group
        # Create the parallel GPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        # A list of workers to run a model.
        self.workers: List[RayWorkerWrapper] = []
        if self.parallel_config.ray_workers_use_nsight:
            ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                ray_remote_kwargs)

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("GPU", 0):
                # Skip bundles that don't have GPUs,
                # as each worker needs one GPU.
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            worker = ray.remote(
                num_cpus=0,
                num_gpus=1,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config)
            self.workers.append(worker)

        logger.debug("workers: %s", self.workers)
        worker_ips = [
            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
            for worker in self.workers
        ]
        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        worker_to_ip = dict(zip(self.workers, worker_ips))

        def sort_by_driver_then_worker_ip(worker):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vLLM engine),
               it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
               be placed first.
            3. Finally, if the worker is on a node with a smaller IP address,
               it should be placed first. This is simply a tiebreaker to make
               sure the workers are sorted deterministically.
            """
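            # Tuple comparison gives the ordering above: `ip != driver_ip` is
            # False for workers on the driver's node, and False sorts before
            # True; ties fall back to the node's worker count and finally to
            # the raw IP string.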
            ip = worker_to_ip[worker]
            return (ip != driver_ip, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)

        # Get the set of GPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids")

        node_workers = defaultdict(list)  # node id -> list of worker ranks
        node_gpus = defaultdict(list)  # node id -> list of gpu ids

        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)
            # `gpu_ids` can be a list of strings or integers.
            # convert them to integers for consistency.
            # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
            # string sorting is not sufficient.
            # see https://github.com/vllm-project/vllm/issues/5590
            gpu_ids = [int(x) for x in gpu_ids]
            node_gpus[node_id].extend(gpu_ids)

        for node_id, gpu_ids in node_gpus.items():
            node_gpus[node_id] = sorted(gpu_ids)

        all_ips = set(worker_ips)
        n_ips = len(all_ips)
        n_nodes = len(node_workers)

        if n_nodes != n_ips:
            raise RuntimeError(
                f"Every node should have a unique IP address. Got {n_nodes}"
                f" nodes with node ids {list(node_workers.keys())} and "
                f"{n_ips} unique IP addresses {all_ips}. Please check your"
                " network configuration. If you set `VLLM_HOST_IP` or "
                "`HOST_IP` environment variable, make sure it is unique for"
                " each node.")

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "CUDA_VISIBLE_DEVICES":
            ",".join(map(str, node_gpus[node_id])),
            "VLLM_TRACE_FUNCTION":
            str(envs.VLLM_TRACE_FUNCTION),
            "VLLM_USE_V1":
            str(int(envs.VLLM_USE_V1)),
            **({
                "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
            } if envs.VLLM_ATTENTION_BACKEND is not None else {})
        }, ) for (node_id, _) in worker_node_and_gpu_ids]

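        # NOTE: Each entry above is a one-element tuple (note the trailing
        # comma), so it lines up with the `all_args` list of per-worker
        # positional args expected by `_run_workers` below.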
        self._env_vars_for_all_workers = (
            all_args_to_update_environment_variables)

        self._run_workers("update_environment_variables",
                          all_args=self._get_env_vars_to_be_updated())

        if len(node_gpus) == 1:
            # In the single-node case, we don't need to get the IP address;
            # the loopback address is sufficient.
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
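        # `rank` is the global worker index after sorting, while `local_rank`
        # is the worker's position among the workers on its own node.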
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
        self._run_workers("initialize")
        self._run_workers("load_model")

    def _configure_ray_workers_use_nsight(self,
                                          ray_remote_kwargs) -> Dict[str, Any]:
        # If nsight profiling is enabled, we need to set the profiling
        # configuration for the ray workers as runtime env.
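        # The keys below are expected to be forwarded by Ray's nsight
        # runtime-env support as `nsys profile` options (trace categories,
        # output file pattern, CUDA graph tracing mode).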
        runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
        runtime_env.update({
            "nsight": {
                "t": "cuda,cudnn,cublas",
                "o": "'worker_process_%p'",
                "cuda-graph-trace": "node",
            }
        })

        return ray_remote_kwargs

    def _get_env_vars_to_be_updated(self):
        return self._env_vars_for_all_workers

    def _get_worker_kwargs(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
        """
        Return worker init args for a given rank.
        """
        if distributed_init_method is None:
            distributed_init_method = get_distributed_init_method(
                get_ip(), get_open_port())
        return dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
        )

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """
        Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operations can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks

    def initialize(self, num_gpu_blocks: int) -> None:
        """
        Initialize the KV cache in all workers.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log at the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# GPU blocks: %d", num_gpu_blocks)
        self._run_workers("initialize_cache", num_gpu_blocks)
        self._run_workers("compile_or_warm_up_model")

    def _run_workers(
        self,
        method: str,
        *args,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> Any:
        """
        Runs the given method on all workers. Can be used in the following
        ways:

        Args:
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually
        """
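        # When `all_args`/`all_kwargs` are omitted, the shared `args`/`kwargs`
        # are repeated for every worker; otherwise each worker receives its
        # own entry. Example call patterns from this class:
        #   self._run_workers("initialize")
        #   self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)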
        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 0, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 0, None)

        ray_worker_refs = [
            worker.execute_method.remote(  # type: ignore[attr-defined]
                method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]
        return ray.get(ray_worker_refs)

    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        if self.forward_dag is None:
            self.forward_dag = self._compiled_ray_dag()
        # Only the first worker (with rank 0) returns the execution result.
        # Others return None.
        output = ray.get(self.forward_dag.execute(scheduler_output))[0]
        return output

    def profile(self, is_start=True):
        raise NotImplementedError

    def shutdown(self):
        if hasattr(self, "forward_dag") and self.forward_dag is not None:
            self.forward_dag.teardown()
            import ray
            for worker in self.workers:
                ray.kill(worker)
            self.forward_dag = None

    def check_health(self) -> None:
        logger.debug("Called check_health.")

    def _check_ray_compiled_graph_installation(self):
        import pkg_resources
        from packaging import version

        required_version = version.parse("2.39")
        current_version = version.parse(
            pkg_resources.get_distribution("ray").version)
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} is "
                             f"required, but found {current_version}")

        import importlib.util
        raycg = importlib.util.find_spec("ray.experimental.compiled_dag_ref")
        if raycg is None:
            raise ValueError("Ray Compiled Graph is not installed. "
                             "Run `pip install ray[adag]` to install it.")

        cupy_spec = importlib.util.find_spec("cupy")
        if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
            raise ValueError(
                "cupy is not installed but required since "
                "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
                "Run `pip install ray[adag]` and check cupy installation.")

    def _compiled_ray_dag(self):
        assert self.parallel_config.use_ray
        self._check_ray_compiled_graph_installation()
        from ray.dag import InputNode, MultiOutputNode

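        # Bind every worker's `execute_model` to the same input node, so a
        # single `execute()` call broadcasts the scheduler output to all
        # workers; MultiOutputNode gathers one result per worker, with the
        # rank-0 worker's output first.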
        with InputNode() as input_batches:
            outputs = [
                worker.execute_model.bind(  # type: ignore[attr-defined]
                    input_batches) for worker in self.workers
            ]
            forward_dag = MultiOutputNode(outputs)

        return forward_dag.experimental_compile()

    def __del__(self):
        self.shutdown()