|
| 1 | +""" |
| 2 | +a simple demonstration of RLHF with vLLM, inspired by |
| 3 | +the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF . |
| 4 | +It follows the design that, training processes and inference processes |
| 5 | +are different, and they live on different GPUs. |
| 6 | +Training processes send prompts to inference processes to generate data, |
| 7 | +and also synchronize the weights of the model by broadcasting the weights |
| 8 | +from the training process to the inference process. |
| 9 | +Note that this is a simple demonstration of one training instance and one |
| 10 | +inference instance. In practice, there could be multiple training instances |
| 11 | +and multiple inference instances. For the full implementation, please refer |
| 12 | +to the OpenRLHF framework. |
| 13 | +""" |
| 14 | +import os |
| 15 | + |
| 16 | +import ray |
| 17 | +import torch |
| 18 | +from ray.util.placement_group import placement_group |
| 19 | +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy |
| 20 | +from transformers import AutoModelForCausalLM |
| 21 | + |
| 22 | +from vllm import LLM, SamplingParams, configure_as_vllm_process |
| 23 | +from vllm.utils import get_ip, get_open_port |
| 24 | +from vllm.worker.worker import Worker |
| 25 | + |
| 26 | + |
| 27 | +def stateless_init_process_group(master_address, master_port, rank, world_size, |
| 28 | + device): |
| 29 | + """ |
| 30 | + vLLM provides `StatelessProcessGroup` to create a process group |
| 31 | + without considering the global process group in torch.distributed. |
| 32 | + It is recommended to create `StatelessProcessGroup`, and then initialize |
| 33 | + the data-plane communication (NCCL) between external (train processes) |
| 34 | + and vLLM workers. |
| 35 | + """ |
| 36 | + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator |
| 37 | + from vllm.distributed.utils import StatelessProcessGroup |
| 38 | + pg = StatelessProcessGroup.create(host=master_address, |
| 39 | + port=master_port, |
| 40 | + rank=rank, |
| 41 | + world_size=world_size) |
| 42 | + pynccl = PyNcclCommunicator(pg, device=device) |
| 43 | + return pynccl |
| 44 | + |
| 45 | + |
| 46 | +class MyWorker(Worker): |
| 47 | + """ |
| 48 | + The `MyWorker` class inherits from `Worker` to provide custom functions. |
| 49 | + For simplicity, we define the `MyWorker` class in this self-contained |
| 50 | + script. Normally, we should define the `MyWorker` class in a separate |
| 51 | + file and pass the qualified name of the class to the `worker_cls` |
| 52 | + parameter. |
| 53 | + """ |
| 54 | + |
| 55 | + def init_weight_update_group(self, master_address, master_port, |
| 56 | + rank_offset, world_size): |
| 57 | + from vllm.distributed.parallel_state import get_world_group |
| 58 | + rank = get_world_group().rank + rank_offset |
| 59 | + self.model_update_group = stateless_init_process_group( |
| 60 | + master_address, |
| 61 | + master_port, |
| 62 | + rank, |
| 63 | + world_size, |
| 64 | + self.device, |
| 65 | + ) |
| 66 | + |
| 67 | + def update_weight(self, name, dtype, shape): |
| 68 | + weight = torch.empty(shape, dtype=dtype, device="cuda") |
| 69 | + self.model_update_group.broadcast(weight, |
| 70 | + src=0, |
| 71 | + stream=torch.cuda.current_stream()) |
| 72 | + |
| 73 | + self.model_runner.model.load_weights(weights=[(name, weight)]) |
| 74 | + |
| 75 | + del weight |
| 76 | + |
| 77 | + def check_weights_changed(self): |
| 78 | + """ |
| 79 | + Check if the weights are updated to 0. |
| 80 | + """ |
| 81 | + weights_updated = True |
| 82 | + for name, p in self.model_runner.model.named_parameters(): |
| 83 | + weights_updated = weights_updated and torch.allclose( |
| 84 | + p, torch.zeros_like(p)) |
| 85 | + return weights_updated |
| 86 | + |
| 87 | + |
| 88 | +class MyLLM(LLM): |
| 89 | + |
| 90 | + def __init__(self, *args, **kwargs): |
| 91 | + # a hack to make the script work. |
| 92 | + # stop ray from manipulating CUDA_VISIBLE_DEVICES |
| 93 | + # at the top-level |
| 94 | + del os.environ["CUDA_VISIBLE_DEVICES"] |
| 95 | + super().__init__(*args, **kwargs) |
| 96 | + |
| 97 | + |
| 98 | +""" |
| 99 | +Start the training process, here we use huggingface transformers |
| 100 | +as an example to hold a model on GPU 0. |
| 101 | +
|
| 102 | +It is important for all the processes outside of vLLM to call |
| 103 | +`configure_as_vllm_process` to set some common environment variables |
| 104 | +the same as vLLM workers. |
| 105 | +""" |
| 106 | +configure_as_vllm_process() |
| 107 | + |
| 108 | +train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") |
| 109 | +train_model.to("cuda:0") |
| 110 | +""" |
| 111 | +Start the inference process, here we use vLLM to hold a model on GPU 1 and |
| 112 | +GPU 2. For the details on how to use ray, please refer to the ray |
| 113 | +documentation https://docs.ray.io/en/latest/ . |
| 114 | +""" |
| 115 | +os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" |
| 116 | +ray.init() |
| 117 | + |
| 118 | +pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2) |
| 119 | +ray.get(pg_inference.ready()) |
| 120 | +scheduling_inference = PlacementGroupSchedulingStrategy( |
| 121 | + placement_group=pg_inference, |
| 122 | + placement_group_capture_child_tasks=True, |
| 123 | + placement_group_bundle_index=0, |
| 124 | +) |
| 125 | +""" |
| 126 | +launch the vLLM inference engine. |
| 127 | +here we use `enforce_eager` to reduce the start time. |
| 128 | +""" |
| 129 | +llm = ray.remote( |
| 130 | + num_cpus=0, |
| 131 | + num_gpus=0, |
| 132 | + scheduling_strategy=scheduling_inference, |
| 133 | +)(MyLLM).remote( |
| 134 | + model="facebook/opt-125m", |
| 135 | + enforce_eager=True, |
| 136 | + worker_cls=MyWorker, |
| 137 | + tensor_parallel_size=2, |
| 138 | + distributed_executor_backend="ray", |
| 139 | +) |
| 140 | + |
| 141 | +# Generate texts from the prompts. |
| 142 | +prompts = [ |
| 143 | + "Hello, my name is", |
| 144 | + "The president of the United States is", |
| 145 | + "The capital of France is", |
| 146 | + "The future of AI is", |
| 147 | +] |
| 148 | + |
| 149 | +sampling_params = SamplingParams(temperature=0) |
| 150 | + |
| 151 | +outputs = ray.get(llm.generate.remote(prompts, sampling_params)) |
| 152 | + |
| 153 | +for output in outputs: |
| 154 | + prompt = output.prompt |
| 155 | + generated_text = output.outputs[0].text |
| 156 | + print(f"Prompt: {prompt!r}, " |
| 157 | + f"Generated text: {generated_text!r}") |
| 158 | + |
| 159 | +# set up the communication between the training process |
| 160 | +# and the inference engine. |
| 161 | +master_address = get_ip() |
| 162 | +master_port = get_open_port() |
| 163 | + |
| 164 | +handle = llm.collective_rpc.remote("init_weight_update_group", |
| 165 | + args=(master_address, master_port, 1, 3)) |
| 166 | +model_update_group = stateless_init_process_group(master_address, master_port, |
| 167 | + 0, 3, torch.device("cuda:0")) |
| 168 | +ray.get(handle) |
| 169 | + |
| 170 | +# simulate training, modify the weights of the model. |
| 171 | +for name, p in train_model.named_parameters(): |
| 172 | + p.data.zero_() |
| 173 | + |
| 174 | +# sync weight from the training process to the inference engine. |
| 175 | +for name, p in train_model.named_parameters(): |
| 176 | + handle = llm.collective_rpc.remote("update_weight", |
| 177 | + args=(name, p.dtype, p.shape)) |
| 178 | + model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) |
| 179 | + ray.get(handle) |
| 180 | + |
| 181 | +# check if the weights are updated. |
| 182 | +assert all(ray.get(llm.collective_rpc.remote("check_weights_changed"))) |
| 183 | + |
| 184 | +# use the updated model to generate texts, they will be nonsense |
| 185 | +# because the weights are all zeros. |
| 186 | +outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params)) |
| 187 | +for output in outputs_updated: |
| 188 | + prompt = output.prompt |
| 189 | + generated_text = output.outputs[0].text |
| 190 | + print(f"Prompt: {prompt!r}, " |
| 191 | + f"Generated text: {generated_text!r}") |
0 commit comments