
Commit 0a300b8

youkaichao authored and shreyankg committed
[RLHF] use worker_extension_cls for compatibility with V0 and V1 (vllm-project#14185)
Signed-off-by: youkaichao <[email protected]>
1 parent 1ede7ce commit 0a300b8
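
For orientation, a minimal usage sketch (not taken from this commit) of the new argument, using the same model and extension names as the example scripts. The extension is referenced by its qualified name, so it is imported inside each worker process rather than pickled from the driver:

from vllm import LLM

# Sketch: pass the extension by qualified name; the methods it defines
# become reachable on every worker through collective_rpc.
llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
    worker_extension_cls="rlhf_utils.WorkerExtension",
)
print(llm.collective_rpc("check_weights_changed"))  # one result per worker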

File tree

7 files changed: 153 additions & 100 deletions

.buildkite/test-pipeline.yaml

Lines changed: 4 additions & 2 deletions

@@ -145,8 +145,10 @@ steps:
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
   # TODO: create a dedicated test section for multi-GPU example tests
   # when we have multiple distributed example tests
-  - python3 ../examples/offline_inference/rlhf.py
-  - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
+  - pushd ../examples/offline_inference
+  - python3 rlhf.py
+  - RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
+  - popd
 
 - label: Metrics, Tracing Test # 10min
   num_gpus: 2

examples/offline_inference/rlhf.py

Lines changed: 3 additions & 63 deletions

@@ -18,72 +18,11 @@
 import torch
 from ray.util.placement_group import placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+from rlhf_utils import stateless_init_process_group
 from transformers import AutoModelForCausalLM
 
 from vllm import LLM, SamplingParams
 from vllm.utils import get_ip, get_open_port
-from vllm.worker.worker import Worker
-
-
-def stateless_init_process_group(master_address, master_port, rank, world_size,
-                                 device):
-    """
-    vLLM provides `StatelessProcessGroup` to create a process group
-    without considering the global process group in torch.distributed.
-    It is recommended to create `StatelessProcessGroup`, and then initialize
-    the data-plane communication (NCCL) between external (train processes)
-    and vLLM workers.
-    """
-    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
-    from vllm.distributed.utils import StatelessProcessGroup
-    pg = StatelessProcessGroup.create(host=master_address,
-                                      port=master_port,
-                                      rank=rank,
-                                      world_size=world_size)
-    pynccl = PyNcclCommunicator(pg, device=device)
-    return pynccl
-
-
-class MyWorker(Worker):
-    """
-    The `MyWorker` class inherits from `Worker` to provide custom functions.
-    For simplicity, we define the `MyWorker` class in this self-contained
-    script. Normally, we should define the `MyWorker` class in a separate
-    file and pass the qualified name of the class to the `worker_cls`
-    parameter.
-    """
-
-    def init_weight_update_group(self, master_address, master_port,
-                                 rank_offset, world_size):
-        from vllm.distributed.parallel_state import get_world_group
-        rank = get_world_group().rank + rank_offset
-        self.model_update_group = stateless_init_process_group(
-            master_address,
-            master_port,
-            rank,
-            world_size,
-            self.device,
-        )
-
-    def update_weight(self, name, dtype, shape):
-        weight = torch.empty(shape, dtype=dtype, device="cuda")
-        self.model_update_group.broadcast(weight,
-                                          src=0,
-                                          stream=torch.cuda.current_stream())
-
-        self.model_runner.model.load_weights(weights=[(name, weight)])
-
-        del weight
-
-    def check_weights_changed(self):
-        """
-        Check if the weights are updated to 0.
-        """
-        weights_updated = True
-        for name, p in self.model_runner.model.named_parameters():
-            weights_updated = weights_updated and torch.allclose(
-                p, torch.zeros_like(p))
-        return weights_updated
 
 
 class MyLLM(LLM):
@@ -129,7 +68,7 @@ def __init__(self, *args, **kwargs):
 )(MyLLM).remote(
     model="facebook/opt-125m",
     enforce_eager=True,
-    worker_cls=MyWorker,
+    worker_extension_cls="rlhf_utils.WorkerExtension",
     tensor_parallel_size=2,
     distributed_executor_backend="ray",
 )
@@ -159,6 +98,7 @@ def __init__(self, *args, **kwargs):
 
 handle = llm.collective_rpc.remote("init_weight_update_group",
                                    args=(master_address, master_port, 1, 3))
+
 model_update_group = stateless_init_process_group(master_address, master_port,
                                                   0, 3, torch.device("cuda:0"))
 ray.get(handle)
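
Each `update_weight` RPC on the workers is meant to be paired with a matching broadcast from the training side of the same stateless group. A condensed sketch of how such a loop can look, assuming `train_model` is the Hugging Face training model loaded on cuda:0 as in the example script:

# Illustrative loop: ask the workers to receive a broadcast for one
# parameter, then broadcast the trainer's copy from rank 0 of the group.
for name, p in train_model.named_parameters():
    handle = llm.collective_rpc.remote("update_weight",
                                       args=(name, p.dtype, p.shape))
    model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
    ray.get(handle)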

examples/offline_inference/rlhf_colocate.py

Lines changed: 1 addition & 35 deletions

@@ -17,40 +17,6 @@
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
 from vllm import LLM
-from vllm.worker.worker import Worker
-
-
-class MyWorker(Worker):
-
-    def report_device_id(self) -> str:
-        from vllm.platforms import current_platform
-        self.device_uuid = current_platform.get_device_uuid(self.device.index)
-        return self.device_uuid
-
-    def update_weights_from_ipc_handles(self, ipc_handles):
-        handles = ipc_handles[self.device_uuid]
-        device_id = self.device.index
-        weights = []
-        for name, handle in handles.items():
-            func, args = handle
-            list_args = list(args)
-            # the key is to change device id to the current device id
-            # in case two processes have different CUDA_VISIBLE_DEVICES
-            list_args[6] = device_id
-            tensor = func(*list_args)
-            weights.append((name, tensor))
-        self.model_runner.model.load_weights(weights=weights)
-        torch.cuda.synchronize()
-
-    def check_weights_changed(self):
-        """
-        Check if the weights are updated to 0.
-        """
-        weights_updated = True
-        for name, p in self.model_runner.model.named_parameters():
-            weights_updated = weights_updated and torch.allclose(
-                p, torch.zeros_like(p))
-        return weights_updated
 
 
 class MyLLM(LLM):
@@ -150,7 +116,7 @@ def get_weight_ipc_handles(self):
 )(MyLLM).remote(
     model="facebook/opt-125m",
     enforce_eager=True,
-    worker_cls=MyWorker,
+    worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
    gpu_memory_utilization=0.4,
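
On the training side of the colocated example, `update_weights_from_ipc_handles` expects a dict keyed by device UUID whose values map parameter names to CUDA IPC handles. A hedged sketch of producing that dict with `torch.multiprocessing.reductions.reduce_tensor`, assuming one training actor per GPU; the function and variable names below are illustrative:

from torch.multiprocessing.reductions import reduce_tensor
from vllm.platforms import current_platform

def get_weight_ipc_handles(model, device_index):
    # reduce_tensor returns a (rebuild_func, args) pair for a CUDA tensor;
    # the worker later rewrites the device index inside args (list_args[6])
    # before rebuilding, so differing CUDA_VISIBLE_DEVICES still line up.
    data = {name: reduce_tensor(p.detach())
            for name, p in model.named_parameters()}
    device_uuid = current_platform.get_device_uuid(device_index)
    return {device_uuid: data}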
examples/offline_inference/rlhf_utils.py

Lines changed: 105 additions & 0 deletions

@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def stateless_init_process_group(master_address, master_port, rank, world_size,
+                                 device):
+    """
+    vLLM provides `StatelessProcessGroup` to create a process group
+    without considering the global process group in torch.distributed.
+    It is recommended to create `StatelessProcessGroup`, and then initialize
+    the data-plane communication (NCCL) between external (train processes)
+    and vLLM workers.
+    """
+    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+    from vllm.distributed.utils import StatelessProcessGroup
+    pg = StatelessProcessGroup.create(host=master_address,
+                                      port=master_port,
+                                      rank=rank,
+                                      world_size=world_size)
+    pynccl = PyNcclCommunicator(pg, device=device)
+    return pynccl
+
+
+class WorkerExtension:
+    """
+    The class for vLLM's worker to inherit from.
+    By defining an extension class, the code can work no matter what is
+    the underlying worker class. This way, the code can be compatible
+    with both vLLM V0 and V1.
+    NOTE: we define this class in a separate module, and the main module
+    should pass the full qualified name as `worker_extension_cls` argument.
+    """
+
+    def init_weight_update_group(self, master_address, master_port,
+                                 rank_offset, world_size):
+        from vllm.distributed.parallel_state import get_world_group
+        rank = get_world_group().rank + rank_offset
+        self.model_update_group = stateless_init_process_group(
+            master_address,
+            master_port,
+            rank,
+            world_size,
+            self.device,
+        )
+
+    def update_weight(self, name, dtype, shape):
+        weight = torch.empty(shape, dtype=dtype, device="cuda")
+        self.model_update_group.broadcast(weight,
+                                          src=0,
+                                          stream=torch.cuda.current_stream())
+
+        self.model_runner.model.load_weights(weights=[(name, weight)])
+
+        del weight
+
+    def check_weights_changed(self):
+        """
+        Check if the weights are updated to 0.
+        """
+        weights_updated = True
+        for name, p in self.model_runner.model.named_parameters():
+            weights_updated = weights_updated and torch.allclose(
+                p, torch.zeros_like(p))
+        return weights_updated
+
+
+class ColocateWorkerExtension:
+    """
+    The class for vLLM's worker to inherit from, in the colocate setting.
+    By defining an extension class, the code can work no matter what is
+    the underlying worker class. This way, the code can be compatible
+    with both vLLM V0 and V1.
+    NOTE: we define this class in a separate module, and the main module
+    should pass the full qualified name as `worker_extension_cls` argument.
+    """
+
+    def report_device_id(self) -> str:
+        from vllm.platforms import current_platform
+        self.device_uuid = current_platform.get_device_uuid(self.device.index)
+        return self.device_uuid
+
+    def update_weights_from_ipc_handles(self, ipc_handles):
+        handles = ipc_handles[self.device_uuid]
+        device_id = self.device.index
+        weights = []
+        for name, handle in handles.items():
+            func, args = handle
+            list_args = list(args)
+            # the key is to change device id to the current device id
+            # in case two processes have different CUDA_VISIBLE_DEVICES
+            list_args[6] = device_id
+            tensor = func(*list_args)
+            weights.append((name, tensor))
+        self.model_runner.model.load_weights(weights=weights)
+        torch.cuda.synchronize()
+
+    def check_weights_changed(self):
+        """
+        Check if the weights are updated to 0.
+        """
+        weights_updated = True
+        for name, p in self.model_runner.model.named_parameters():
+            weights_updated = weights_updated and torch.allclose(
+                p, torch.zeros_like(p))
+        return weights_updated
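
`check_weights_changed` only reports whether every parameter is zero, so both example scripts can verify the full update path by zeroing the training model first. A hedged end-to-end check, with `train_model` standing in for the training-side model and the transfer done by whichever mechanism the script uses:

# Illustrative verification: zero the trainer's weights, push them to the
# vLLM workers, then confirm every worker now sees all-zero parameters.
for _, p in train_model.named_parameters():
    p.data.zero_()
# ... transfer weights via update_weight broadcasts or IPC handles ...
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))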

vllm/config.py

Lines changed: 4 additions & 0 deletions

@@ -1366,6 +1366,7 @@ class ParallelConfig:
     # will be determined based on the platform.
     worker_cls: str = "auto"
     sd_worker_cls: str = "auto"
+    worker_extension_cls: str = ""
 
     # world_size is TPxPP, it affects the number of workers we create.
     world_size: int = field(init=False)
@@ -1523,6 +1524,9 @@ def _verify_args(self) -> None:
             raise ValueError("Unable to use nsight profiling unless workers "
                              "run with Ray.")
 
+        assert isinstance(self.worker_extension_cls, str), (
+            "worker_extension_cls must be a string (qualified class name).")
+
 
 @dataclass
 class SchedulerConfig:

vllm/engine/arg_utils.py

Lines changed: 9 additions & 0 deletions

@@ -203,6 +203,7 @@ class EngineArgs:
     override_pooler_config: Optional[PoolerConfig] = None
     compilation_config: Optional[CompilationConfig] = None
     worker_cls: str = "auto"
+    worker_extension_cls: str = ""
 
     kv_transfer_config: Optional[KVTransferConfig] = None
 
@@ -1016,6 +1017,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             type=str,
             default="auto",
             help='The worker class to use for distributed execution.')
+        parser.add_argument(
+            '--worker-extension-cls',
+            type=str,
+            default="",
+            help='The worker extension class on top of the worker cls, '
+            'it is useful if you just want to add new functions to the worker '
+            'class without changing the existing functions.')
         parser.add_argument(
             "--generation-config",
             type=nullable_str,
@@ -1210,6 +1218,7 @@ def create_engine_config(self,
             ray_workers_use_nsight=self.ray_workers_use_nsight,
             distributed_executor_backend=self.distributed_executor_backend,
             worker_cls=self.worker_cls,
+            worker_extension_cls=self.worker_extension_cls,
         )
 
         max_model_len = model_config.max_model_len
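
A hedged sketch of the path the new flag takes through the CLI: it is parsed like any other engine argument and ends up on `ParallelConfig` via `create_engine_config`. The model and extension names below are simply the ones used in the example scripts:

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

# Parse the flag, then build EngineArgs from the namespace; the value is
# forwarded as worker_extension_cls when ParallelConfig is constructed.
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
ns = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--worker-extension-cls", "rlhf_utils.WorkerExtension",
])
engine_args = EngineArgs.from_cli_args(ns)
print(engine_args.worker_extension_cls)  # rlhf_utils.WorkerExtension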

vllm/worker/worker_base.py

Lines changed: 27 additions & 0 deletions

@@ -558,10 +558,37 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
             worker_class = resolve_obj_by_qualname(
                 self.vllm_config.parallel_config.worker_cls)
         else:
+            logger.warning(
+                "passing worker_cls as a class object is strongly deprecated,"
+                " as the serialization of class objects can be tricky and"
+                " error-prone. To be safe, please keep the class in a separate"
+                " module and pass the qualified name of the class as a string."
+            )
             assert isinstance(self.vllm_config.parallel_config.worker_cls,
                               bytes)
             worker_class = cloudpickle.loads(
                 self.vllm_config.parallel_config.worker_cls)
+        if self.vllm_config.parallel_config.worker_extension_cls:
+            worker_extension_cls = resolve_obj_by_qualname(
+                self.vllm_config.parallel_config.worker_extension_cls)
+            extended_calls = []
+            if worker_extension_cls not in worker_class.__bases__:
+                # check any conflicts between worker and worker_extension_cls
+                for attr in dir(worker_extension_cls):
+                    if attr.startswith("__"):
+                        continue
+                    assert not hasattr(worker_class, attr), (
+                        f"Worker class {worker_class} already has an attribute"
+                        f" {attr}, which conflicts with the worker"
+                        f" extension class {worker_extension_cls}.")
+                    if callable(getattr(worker_extension_cls, attr)):
+                        extended_calls.append(attr)
+                # dynamically inherit the worker extension class
+                worker_class.__bases__ = worker_class.__bases__ + (
+                    worker_extension_cls, )
+                logger.info(
+                    "Injected %s into %s for extended collective_rpc calls %s",
+                    worker_extension_cls, worker_class, extended_calls)
         with set_current_vllm_config(self.vllm_config):
             # To make vLLM config available during worker initialization
             self.worker = worker_class(**kwargs)
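
The injection relies on the fact that CPython allows a class's `__bases__` to be extended at runtime, after which the methods of the appended class become available through the updated MRO. A self-contained toy of the same pattern (all names are placeholders; a small intermediate base class is used because reassigning `__bases__` on a class that inherits directly from `object` can be rejected by CPython):

class _Base:
    pass

class Worker(_Base):          # stands in for the resolved worker class
    def existing_call(self):
        return "existing"

class Extension:              # stands in for worker_extension_cls
    def extra_call(self):
        return "extra"

w = Worker()
# Dynamically inherit the extension, mirroring the code above.
Worker.__bases__ = Worker.__bases__ + (Extension,)
print(w.extra_call())         # "extra": even pre-existing instances gain it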
