@@ -266,44 +266,104 @@ def __init__(
         executor_class: type[Executor],
         log_stats: bool,
     ):
+        import copy
+
         import ray
+        from ray._private.state import available_resources_per_node
+        from ray.util.scheduling_strategies import (
+            PlacementGroupSchedulingStrategy)
+        from ray.util.state import list_nodes

         from vllm.v1.engine.core import DPEngineCoreActor

         self.local_engine_actors: list[ray.ActorHandle] = []
         self.remote_engine_actors: list[ray.ActorHandle] = []

-        # TODO(rui): use proper placement strategy to put engine actors
-        # on the desired nodes.
+        dp_size = vllm_config.parallel_config.data_parallel_size
+        remote_engine_count = dp_size - local_engine_count
+
+        if ray.is_initialized():
+            logger.info(
+                "Ray is already initialized. Skipping Ray initialization.")
+        else:
+            ray.init()
+
+        nodes = list_nodes()
+        available_resources_by_id = available_resources_per_node()
+        available_resources_by_ip = {}
+        num_workers = vllm_config.parallel_config.world_size
+
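+        # Feasibility check: count how many DP ranks (num_workers GPUs each,
+        # all on a single node) the cluster can currently host.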
+        dp_size_available = 0
+        for node in nodes:
+            node_ip = node.node_ip
+            node_id = node.node_id
+            node_resources = available_resources_by_id[node_id]
+            available_resources_by_ip[node_ip] = node_resources
+            # For now, each DP rank can only be assigned to one node
+            # TODO(rui): support allocating a single DP rank to multiple nodes
+            dp_size_available += node_resources["GPU"] // num_workers
+
+        assert dp_size_available >= dp_size, (
+            "Not enough resources to allocate DP ranks")
+
+        head_node_ip = \
+            vllm_config.parallel_config.data_parallel_master_ip
+
         refs = []
         for index in range(local_engine_count):
             local_index = local_start_index + index
             global_index = start_index + index
-            actor = ray.remote(DPEngineCoreActor).remote(
-                vllm_config=vllm_config,
-                executor_class=executor_class,
-                log_stats=log_stats,
-                addresses=addresses,
-                on_head_node=True,
-                engine_index=global_index,
-                dp_rank=global_index,
-                local_dp_rank=local_index)
+            dp_vllm_config = copy.deepcopy(vllm_config)
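+            # num_workers GPU bundles, each pinned to the head node through
+            # its "node:<ip>" custom resource, plus one CPU-only bundle for
+            # the engine-core actor itself.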
+            bundles = [{
+                "GPU": 1.0,
+                "node:" + head_node_ip: 0.001
+            }] * num_workers + [{
+                "CPU": 1.0
+            }]
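+            # STRICT_PACK places every bundle on the same node, enforcing
+            # the one-node-per-DP-rank assumption above.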
+            pg = ray.util.placement_group(
+                name=f"dp_rank_{global_index}",
+                strategy="STRICT_PACK",
+                bundles=bundles,
+            )
+            dp_vllm_config.parallel_config.placement_group = pg
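+            # Schedule the actor onto the trailing CPU-only bundle (index
+            # num_workers), leaving the GPU bundles to the workers.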
+            actor = ray.remote(DPEngineCoreActor).options(
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    placement_group=pg,
+                    placement_group_bundle_index=num_workers,
+                )).remote(vllm_config=dp_vllm_config,
+                          executor_class=executor_class,
+                          log_stats=log_stats,
+                          addresses=addresses,
+                          on_head_node=True,
+                          engine_index=global_index,
+                          dp_rank=global_index,
+                          local_dp_rank=local_index)
             self.local_engine_actors.append(actor)
             refs.append(actor.wait_for_init.remote())

-        dp_size = vllm_config.parallel_config.data_parallel_size
-        for index in range(dp_size - local_engine_count):
+        for index in range(remote_engine_count):
             local_index = index
             global_index = local_engine_count + index
-            actor = ray.remote(DPEngineCoreActor).remote(
-                vllm_config=vllm_config,
-                executor_class=executor_class,
-                log_stats=log_stats,
-                addresses=addresses,
-                on_head_node=False,
-                engine_index=global_index,
-                dp_rank=global_index,
-                local_dp_rank=local_index)
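+            # No head-node pinning here: Ray may place each remote rank's
+            # placement group on any node with enough free GPUs.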
+            bundles = [{"GPU": 1.0}] * num_workers + [{"CPU": 1.0}]
+            pg = ray.util.placement_group(
+                name=f"dp_rank_{global_index}",
+                strategy="STRICT_PACK",
+                bundles=bundles,
+            )
+            dp_vllm_config = copy.deepcopy(vllm_config)
+            dp_vllm_config.parallel_config.placement_group = pg
+            actor = ray.remote(DPEngineCoreActor).options(
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    placement_group=pg,
+                    placement_group_bundle_index=num_workers,
+                )).remote(vllm_config=dp_vllm_config,
+                          executor_class=executor_class,
+                          log_stats=log_stats,
+                          addresses=addresses,
+                          on_head_node=False,
+                          engine_index=global_index,
+                          dp_rank=global_index,
+                          local_dp_rank=local_index)
             self.remote_engine_actors.append(actor)
             refs.append(actor.wait_for_init.remote())