# SPDX-License-Identifier: Apache-2.0
"""
- a simple demonstration to show how to control
- the placement of the vLLM workers with Ray.
- The key is to set VLLM_RAY_PER_WORKER_GPUS and
- VLLM_RAY_BUNDLE_INDICES properly.
+ a simple demonstration to show how to co-locate
+ vLLM workers with training actors on the same GPUs,
+ for RLHF-like applications.
+ The key points:
+ - Control the placement of the vLLM workers with Ray, by setting
+   VLLM_RAY_PER_WORKER_GPUS and VLLM_RAY_BUNDLE_INDICES properly
+   (see the illustrative comment right after this docstring).
+ - Use cuda-ipc to pass tensors, since NCCL does not work when we have
+   multiple processes on the same GPU.
"""
import os
import ray
+ import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -19,7 +24,33 @@ class MyWorker(Worker):

    def report_device_id(self) -> str:
        from vllm.platforms import current_platform
-         return current_platform.get_device_uuid(self.device.index)
+         self.device_uuid = current_platform.get_device_uuid(self.device.index)
+         return self.device_uuid
+
+     def update_weights_from_ipc_handles(self, ipc_handles):
+         handles = ipc_handles[self.device_uuid]
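+         # each handle is the (rebuild_func, args) pair that
+         # torch.multiprocessing.reductions.reduce_tensor produced on the
+         # training side; calling rebuild_func(*args) maps the shared CUDA
+         # memory into this process without copying the weights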
+         device_id = self.device.index
+         weights = []
+         for name, handle in handles.items():
+             func, args = handle
+             list_args = list(args)
+             # the key is to change device id to the current device id
+             # in case two processes have different CUDA_VISIBLE_DEVICES
+             list_args[6] = device_id
+             tensor = func(*list_args)
+             weights.append((name, tensor))
+         self.model_runner.model.load_weights(weights=weights)
+         torch.cuda.synchronize()
+
+     def check_weights_changed(self):
+         """
+         Check if the weights are updated to 0.
+         """
+         weights_updated = True
+         for name, p in self.model_runner.model.named_parameters():
+             weights_updated = weights_updated and torch.allclose(
+                 p, torch.zeros_like(p))
+         return weights_updated


class MyLLM(LLM):
@@ -40,12 +71,32 @@ def __init__(self, *args, bundle_indices: list, **kwargs):

class RayTrainingActor:

-     def report_device_id(self) -> str:
+     def __init__(self):
+         # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
+         from transformers import AutoModelForCausalLM
+         self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
+         self.model.to("cuda:0")
+         for name, p in self.model.named_parameters():
+             p.data.zero_()
+         torch.cuda.synchronize()
        # the argument for get_device_uuid is the index
        # of the GPU in the visible devices.
-         # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
        from vllm.platforms import current_platform
-         return current_platform.get_device_uuid(0)
+         self.device_uuid = current_platform.get_device_uuid(0)
+
+     def report_device_id(self) -> str:
+         return self.device_uuid
+
+     def get_weight_ipc_handles(self):
+         from torch.multiprocessing.reductions import reduce_tensor
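+         # reduce_tensor(t) returns a picklable (rebuild_func, args) pair for
+         # the CUDA tensor t, backed by a CUDA IPC memory handle, so another
+         # process on the same GPU can map the same memory without copying it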
+         data = {}
+         for name, p in self.model.named_parameters():
+             # the training actor might only have a subset of the weights
+             # and would need to all-gather the weights from all the actors.
+             # for demonstration, here we assume all training actors have
+             # the full weights.
+             data[name] = reduce_tensor(p.detach())
+         return {self.device_uuid: data}
# ray manages 4 GPUs
@@ -78,6 +129,8 @@ def report_device_id(self) -> str:
        ),
    )(RayTrainingActor).remote()
    training_actors.append(training_actor)
+
+ for bundle_index, training_actor in enumerate(training_actors):
    device_id = ray.get(training_actor.report_device_id.remote())
    print(f"training actor {bundle_index} is on {device_id}")
    training_actor_device_ids.append(device_id)
@@ -119,3 +172,18 @@ def report_device_id(self) -> str:
# the last two training actors should be
# on the same GPUs as the second inference engine
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
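+ # with 4 GPUs managed by Ray, the intended layout is that training actors
+ # 0 and 1 share their GPUs with the first inference engine, while actors
+ # 2 and 3 share theirs with the second one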
+
+ print("gather all the IPC handles from the training actors")
+ ipc_handles = {}
+ for actor in training_actors:
+     ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))
+
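+ # collective_rpc runs the named MyWorker method on every worker of an
+ # engine; here it pushes the gathered CUDA IPC handles into each vLLM
+ # worker so the worker can rebuild the tensors and load the new weights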
+ print("update the weights of the inference engines")
+ for llm in inference_engines:
+     ray.get(
+         llm.collective_rpc.remote("update_weights_from_ipc_handles",
+                                   args=(ipc_handles, )))
+ print("check if the weights are updated")
+ for llm in inference_engines:
+     assert ray.get(
+         llm.collective_rpc.remote("check_weights_changed", args=tuple()))