
Commit 92e793d

[core] LLM.collective_rpc interface and RLHF example (#12084)
Signed-off-by: youkaichao <[email protected]>
1 parent: bf53e0c

6 files changed: +270 −35 lines

.buildkite/test-pipeline.yaml

Lines changed: 4 additions & 0 deletions
@@ -126,11 +126,15 @@ steps:
 - tests/distributed
 - tests/spec_decode/e2e/test_integration_dist_tp4
 - tests/compile
+- examples/offline_inference/rlhf.py
 commands:
 - pytest -v -s distributed/test_utils.py
 - pytest -v -s compile/test_basic_correctness.py
 - pytest -v -s distributed/test_pynccl.py
 - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
+# TODO: create a dedicated test section for multi-GPU example tests
+# when we have multiple distributed example tests
+- python3 ../examples/offline_inference/rlhf.py

 - label: Metrics, Tracing Test # 10min
 num_gpus: 2

examples/offline_inference/rlhf.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@ (new file)
"""
a simple demonstration of RLHF with vLLM, inspired by
the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
It follows the design that, training processes and inference processes
are different, and they live on different GPUs.
Training processes send prompts to inference processes to generate data,
and also synchronize the weights of the model by broadcasting the weights
from the training process to the inference process.
Note that this is a simple demonstration of one training instance and one
inference instance. In practice, there could be multiple training instances
and multiple inference instances. For the full implementation, please refer
to the OpenRLHF framework.
"""
import os

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams, configure_as_vllm_process
from vllm.utils import get_ip, get_open_port
from vllm.worker.worker import Worker


def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):
    """
    vLLM provides `StatelessProcessGroup` to create a process group
    without considering the global process group in torch.distributed.
    It is recommended to create `StatelessProcessGroup`, and then initialize
    the data-plane communication (NCCL) between external (train processes)
    and vLLM workers.
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup
    pg = StatelessProcessGroup.create(host=master_address,
                                      port=master_port,
                                      rank=rank,
                                      world_size=world_size)
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl


class MyWorker(Worker):
    """
    The `MyWorker` class inherits from `Worker` to provide custom functions.
    For simplicity, we define the `MyWorker` class in this self-contained
    script. Normally, we should define the `MyWorker` class in a separate
    file and pass the qualified name of the class to the `worker_cls`
    parameter.
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        from vllm.distributed.parallel_state import get_world_group
        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )

    def update_weight(self, name, dtype, shape):
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(weight,
                                          src=0,
                                          stream=torch.cuda.current_stream())

        self.model_runner.model.load_weights(weights=[(name, weight)])

        del weight

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated


class MyLLM(LLM):

    def __init__(self, *args, **kwargs):
        # a hack to make the script work.
        # stop ray from manipulating CUDA_VISIBLE_DEVICES
        # at the top-level
        del os.environ["CUDA_VISIBLE_DEVICES"]
        super().__init__(*args, **kwargs)


"""
Start the training process, here we use huggingface transformers
as an example to hold a model on GPU 0.

It is important for all the processes outside of vLLM to call
`configure_as_vllm_process` to set some common environment variables
the same as vLLM workers.
"""
configure_as_vllm_process()

train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
"""
Start the inference process, here we use vLLM to hold a model on GPU 1 and
GPU 2. For the details on how to use ray, please refer to the ray
documentation https://docs.ray.io/en/latest/ .
"""
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()

pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)
"""
launch the vLLM inference engine.
here we use `enforce_eager` to reduce the start time.
"""
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
    model="facebook/opt-125m",
    enforce_eager=True,
    worker_cls=MyWorker,
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
)

# Generate texts from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0)

outputs = ray.get(llm.generate.remote(prompts, sampling_params))

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")

# set up the communication between the training process
# and the inference engine.
master_address = get_ip()
master_port = get_open_port()

handle = llm.collective_rpc.remote("init_weight_update_group",
                                   args=(master_address, master_port, 1, 3))
model_update_group = stateless_init_process_group(master_address, master_port,
                                                  0, 3, torch.device("cuda:0"))
ray.get(handle)

# simulate training, modify the weights of the model.
for name, p in train_model.named_parameters():
    p.data.zero_()

# sync weight from the training process to the inference engine.
for name, p in train_model.named_parameters():
    handle = llm.collective_rpc.remote("update_weight",
                                       args=(name, p.dtype, p.shape))
    model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
    ray.get(handle)

# check if the weights are updated.
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))

# use the updated model to generate texts, they will be nonsense
# because the weights are all zeros.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
for output in outputs_updated:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")

vllm/__init__.py

Lines changed: 39 additions & 0 deletions
@@ -17,6 +17,44 @@
 
 from .version import __version__, __version_tuple__
 
+
+def configure_as_vllm_process():
+    """
+    set some common config/environment variables that should be set
+    for all processes created by vllm and all processes
+    that interact with vllm workers.
+    """
+    import os
+
+    import torch
+
+    # see https://github.com/NVIDIA/nccl/issues/1234
+    os.environ['NCCL_CUMEM_ENABLE'] = '0'
+
+    # see https://github.com/vllm-project/vllm/issues/10480
+    os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
+    # see https://github.com/vllm-project/vllm/issues/10619
+    torch._inductor.config.compile_threads = 1
+
+    from vllm.platforms import current_platform
+
+    if current_platform.is_xpu():
+        # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
+        torch._dynamo.config.disable = True
+    elif current_platform.is_hpu():
+        # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
+        # does not support torch.compile
+        # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
+        # torch.compile support
+        is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
+        if is_lazy:
+            torch._dynamo.config.disable = True
+            # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
+            # requires enabling lazy collectives
+            # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
+            os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
+
+
 __all__ = [
     "__version__",
     "__version_tuple__",
@@ -42,4 +80,5 @@
     "AsyncEngineArgs",
     "initialize_ray_cluster",
     "PoolingParams",
+    "configure_as_vllm_process",
 ]
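
The helper is exported at the top level so that external processes (such as the trainer in the RLHF example) can adopt the same NCCL and torch.compile settings as vLLM workers before they allocate communicators or compile anything. A minimal sketch of the intended call order in a hypothetical trainer script (the model choice and device are assumptions):

# Hypothetical trainer-side call site (not part of the commit): call
# configure_as_vllm_process() first so NCCL_CUMEM_ENABLE, the inductor thread
# count, and platform-specific dynamo settings match the vLLM workers.
from transformers import AutoModelForCausalLM

from vllm import configure_as_vllm_process

configure_as_vllm_process()

train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
# ... set up StatelessProcessGroup / PyNcclCommunicator and broadcast weights
# as shown in examples/offline_inference/rlhf.py.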

vllm/entrypoints/llm.py

Lines changed: 25 additions & 0 deletions
@@ -4,6 +4,7 @@
 from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
                     Union, cast, overload)
 
+import cloudpickle
 from tqdm import tqdm
 from typing_extensions import deprecated
 
@@ -186,6 +187,13 @@ def __init__(
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
 
+        if "worker_cls" in kwargs:
+            worker_cls = kwargs["worker_cls"]
+            # if the worker_cls is not qualified string name,
+            # we serialize it using cloudpickle to avoid pickling issues
+            if isinstance(worker_cls, type):
+                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
+
         if compilation_config is not None:
             if isinstance(compilation_config, (int, dict)):
                 compilation_config_instance = CompilationConfig.from_cli(
@@ -455,6 +463,23 @@ def generate(
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return self.engine_class.validate_outputs(outputs, RequestOutput)
 
+    def collective_rpc(self,
+                       method: str,
+                       timeout: Optional[float] = None,
+                       args: Tuple = (),
+                       kwargs: Optional[Dict] = None) -> List[Any]:
+        """
+        Run a method on all workers, with homogeneous arguments.
+        The main extension point for the LLM entrypoint.
+        Users can provide custom worker class through `worker_cls`
+        argument, and implement new methods in the worker class.
+        Then, users can call the new methods through this API.
+        It is recommended to use this API to only pass control messages,
+        and set up data-plane communication to pass data.
+        """
+        return self.llm_engine.model_executor.collective_rpc(
+            method, timeout, args, kwargs)
+
     def beam_search(
         self,
         prompts: List[Union[TokensPrompt, TextPrompt]],
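
`collective_rpc` runs one method call per worker with the same arguments and returns the per-worker results as a list. A short sketch of the calling pattern with a custom worker; `EchoWorker` and its `device_name` method are assumed names for illustration (not part of vLLM), and two GPUs are assumed for the tensor-parallel setting:

# Sketch (assumed class/method names, not part of the commit): expose a custom
# worker method and invoke it on every worker through LLM.collective_rpc.
import torch

from vllm import LLM
from vllm.worker.worker import Worker


class EchoWorker(Worker):

    def device_name(self) -> str:
        # runs in each worker process, on that worker's GPU
        return torch.cuda.get_device_name(self.device)


llm = LLM(model="facebook/opt-125m",
          worker_cls=EchoWorker,
          tensor_parallel_size=2,
          enforce_eager=True)

# One entry per worker; keep payloads small -- the docstring above recommends
# using collective_rpc for control messages and a separate data plane
# (e.g. NCCL, as in the RLHF example) for bulk tensors.
print(llm.collective_rpc("device_name"))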

vllm/plugins/__init__.py

Lines changed: 0 additions & 31 deletions
@@ -1,9 +1,6 @@
 import logging
-import os
 from typing import Callable, Dict
 
-import torch
-
 import vllm.envs as envs
 
 logger = logging.getLogger(__name__)
@@ -50,34 +47,6 @@ def load_general_plugins():
     processes. They should be designed in a way that they can be loaded
     multiple times without causing issues.
     """
-
-    # all processes created by vllm will load plugins,
-    # and here we can inject some common environment variables
-    # for all processes.
-
-    # see https://github.com/vllm-project/vllm/issues/10480
-    os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
-    # see https://github.com/vllm-project/vllm/issues/10619
-    torch._inductor.config.compile_threads = 1
-
-    from vllm.platforms import current_platform
-
-    if current_platform.is_xpu():
-        # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
-        torch._dynamo.config.disable = True
-    if current_platform.is_hpu():
-        # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
-        # does not support torch.compile
-        # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
-        # torch.compile support
-        is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
-        if is_lazy:
-            torch._dynamo.config.disable = True
-            # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
-            # requires enabling lazy collectives
-            # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
-            os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
-
     global plugins_loaded
     if plugins_loaded:
         return

vllm/worker/worker_base.py

Lines changed: 11 additions & 4 deletions
@@ -4,6 +4,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
 
+import cloudpickle
 import torch
 
 from vllm.config import ObservabilityConfig, VllmConfig
@@ -521,14 +522,20 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
         kwargs = all_kwargs[self.rpc_rank]
         enable_trace_function_call_for_thread(self.vllm_config)
 
-        # see https://github.com/NVIDIA/nccl/issues/1234
-        os.environ['NCCL_CUMEM_ENABLE'] = '0'
+        from vllm import configure_as_vllm_process
+        configure_as_vllm_process()
 
         from vllm.plugins import load_general_plugins
         load_general_plugins()
 
-        worker_class = resolve_obj_by_qualname(
-            self.vllm_config.parallel_config.worker_cls)
+        if isinstance(self.vllm_config.parallel_config.worker_cls, str):
+            worker_class = resolve_obj_by_qualname(
+                self.vllm_config.parallel_config.worker_cls)
+        else:
+            assert isinstance(self.vllm_config.parallel_config.worker_cls,
+                              bytes)
+            worker_class = cloudpickle.loads(
+                self.vllm_config.parallel_config.worker_cls)
         self.worker = worker_class(**kwargs)
         assert self.worker is not None
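
With this change, `parallel_config.worker_cls` can arrive either as a qualified class name (resolved on each worker) or as cloudpickled bytes produced by `LLM.__init__`. A brief sketch contrasting the two spellings a caller can now use; `PingWorker` is an assumed name for illustration:

# Sketch of the two accepted forms of `worker_cls` after this commit
# (illustrative; create only one LLM per process in practice).
from vllm import LLM
from vllm.worker.worker import Worker


class PingWorker(Worker):  # defined in the launching script, as in rlhf.py

    def ping(self):
        return "pong"


# 1) A qualified string name, resolved on each worker process with
#    resolve_obj_by_qualname (here the stock worker, spelled explicitly):
# llm = LLM(model="facebook/opt-125m", worker_cls="vllm.worker.worker.Worker")

# 2) The class object itself: LLM.__init__ serializes it with cloudpickle and
#    init_worker above deserializes the bytes, so classes defined in the
#    launching script (even in __main__) can reach the workers.
llm = LLM(model="facebook/opt-125m", worker_cls=PingWorker, enforce_eager=True)
print(llm.collective_rpc("ping"))  # one "pong" per worker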
