
[RLHF] use worker_extension_cls for compatibility with V0 and V1 #14185


Merged
16 commits merged on Mar 6, 2025
6 changes: 4 additions & 2 deletions .buildkite/test-pipeline.yaml
@@ -143,8 +143,10 @@ steps:
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
   # TODO: create a dedicated test section for multi-GPU example tests
   # when we have multiple distributed example tests
-  - python3 ../examples/offline_inference/rlhf.py
-  - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
+  - pushd ../examples/offline_inference
+  - python3 rlhf.py
+  - RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
+  - popd

 - label: Metrics, Tracing Test # 10min
   num_gpus: 2
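The pushd/popd pair is needed because rlhf.py now imports a sibling module (see the rlhf.py diff below), and Python resolves that import relative to the current working directory:

# rlhf.py can only find rlhf_utils.py when launched from inside
# examples/offline_inference/ -- hence the pushd/popd in the pipeline above.
from rlhf_utils import stateless_init_process_group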
66 changes: 3 additions & 63 deletions examples/offline_inference/rlhf.py
@@ -18,72 +18,11 @@
 import torch
 from ray.util.placement_group import placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+from rlhf_utils import stateless_init_process_group
 from transformers import AutoModelForCausalLM

 from vllm import LLM, SamplingParams
 from vllm.utils import get_ip, get_open_port
-from vllm.worker.worker import Worker
-
-
-def stateless_init_process_group(master_address, master_port, rank, world_size,
-                                 device):
-    """
-    vLLM provides `StatelessProcessGroup` to create a process group
-    without considering the global process group in torch.distributed.
-    It is recommended to create `StatelessProcessGroup`, and then initialize
-    the data-plane communication (NCCL) between external (train processes)
-    and vLLM workers.
-    """
-    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
-    from vllm.distributed.utils import StatelessProcessGroup
-    pg = StatelessProcessGroup.create(host=master_address,
-                                      port=master_port,
-                                      rank=rank,
-                                      world_size=world_size)
-    pynccl = PyNcclCommunicator(pg, device=device)
-    return pynccl
-
-
-class MyWorker(Worker):
-    """
-    The `MyWorker` class inherits from `Worker` to provide custom functions.
-    For simplicity, we define the `MyWorker` class in this self-contained
-    script. Normally, we should define the `MyWorker` class in a separate
-    file and pass the qualified name of the class to the `worker_cls`
-    parameter.
-    """
-
-    def init_weight_update_group(self, master_address, master_port,
-                                 rank_offset, world_size):
-        from vllm.distributed.parallel_state import get_world_group
-        rank = get_world_group().rank + rank_offset
-        self.model_update_group = stateless_init_process_group(
-            master_address,
-            master_port,
-            rank,
-            world_size,
-            self.device,
-        )
-
-    def update_weight(self, name, dtype, shape):
-        weight = torch.empty(shape, dtype=dtype, device="cuda")
-        self.model_update_group.broadcast(weight,
-                                          src=0,
-                                          stream=torch.cuda.current_stream())
-
-        self.model_runner.model.load_weights(weights=[(name, weight)])
-
-        del weight
-
-    def check_weights_changed(self):
-        """
-        Check if the weights are updated to 0.
-        """
-        weights_updated = True
-        for name, p in self.model_runner.model.named_parameters():
-            weights_updated = weights_updated and torch.allclose(
-                p, torch.zeros_like(p))
-        return weights_updated


 class MyLLM(LLM):
@@ -129,7 +68,7 @@ def __init__(self, *args, **kwargs):
 )(MyLLM).remote(
     model="facebook/opt-125m",
     enforce_eager=True,
-    worker_cls=MyWorker,
+    worker_adapter_cls="rlhf_utils.WorkerAdapter",
     tensor_parallel_size=2,
     distributed_executor_backend="ray",
 )
@@ -159,6 +98,7 @@ def __init__(self, *args, **kwargs):

 handle = llm.collective_rpc.remote("init_weight_update_group",
                                    args=(master_address, master_port, 1, 3))
+
 model_update_group = stateless_init_process_group(master_address, master_port,
                                                   0, 3, torch.device("cuda:0"))
 ray.get(handle)
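One ordering subtlety in this driver code deserves a note; a minimal sketch of the handshake, stated as comments only (no new API assumed):

# 1) collective_rpc.remote(...) returns an ObjectRef immediately, and the
#    workers then block inside init_weight_update_group waiting for all ranks.
# 2) The driver joins the same rendezvous as rank 0 via
#    stateless_init_process_group(master_address, master_port, 0, 3, ...).
# 3) Only then can ray.get(handle) complete; swapping steps 2 and 3 would
#    deadlock, because the workers cannot finish until rank 0 joins.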
36 changes: 1 addition & 35 deletions examples/offline_inference/rlhf_colocate.py
@@ -17,40 +17,6 @@
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

 from vllm import LLM
-from vllm.worker.worker import Worker
-
-
-class MyWorker(Worker):
-
-    def report_device_id(self) -> str:
-        from vllm.platforms import current_platform
-        self.device_uuid = current_platform.get_device_uuid(self.device.index)
-        return self.device_uuid
-
-    def update_weights_from_ipc_handles(self, ipc_handles):
-        handles = ipc_handles[self.device_uuid]
-        device_id = self.device.index
-        weights = []
-        for name, handle in handles.items():
-            func, args = handle
-            list_args = list(args)
-            # the key is to change device id to the current device id
-            # in case two processes have different CUDA_VISIBLE_DEVICES
-            list_args[6] = device_id
-            tensor = func(*list_args)
-            weights.append((name, tensor))
-        self.model_runner.model.load_weights(weights=weights)
-        torch.cuda.synchronize()
-
-    def check_weights_changed(self):
-        """
-        Check if the weights are updated to 0.
-        """
-        weights_updated = True
-        for name, p in self.model_runner.model.named_parameters():
-            weights_updated = weights_updated and torch.allclose(
-                p, torch.zeros_like(p))
-        return weights_updated


 class MyLLM(LLM):
@@ -150,7 +116,7 @@ def get_weight_ipc_handles(self):
 )(MyLLM).remote(
     model="facebook/opt-125m",
     enforce_eager=True,
-    worker_cls=MyWorker,
+    worker_adapter_cls="rlhf_utils.ColocateWorkerAdapter",
     tensor_parallel_size=2,
     distributed_executor_backend="ray",
     gpu_memory_utilization=0.4,
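For orientation, the driver side of rlhf_colocate.py uses the adapter's report_device_id to pair engines with training actors by GPU UUID; a hedged sketch of that call (the exact return shape, one UUID per tensor-parallel worker, is assumed from the example):

# Ask each vLLM worker which GPU it owns, so the trainer can key the IPC
# handles it exports by device UUID (consumed by
# update_weights_from_ipc_handles below).
handle = llm.collective_rpc.remote("report_device_id", args=tuple())
worker_device_uuids = ray.get(handle)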
105 changes: 105 additions & 0 deletions examples/offline_inference/rlhf_utils.py
@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
import torch


def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):
    """
    vLLM provides `StatelessProcessGroup` to create a process group
    without considering the global process group in torch.distributed.
    It is recommended to create a `StatelessProcessGroup`, and then initialize
    the data-plane communication (NCCL) between the external (train) processes
    and the vLLM workers.
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup
    pg = StatelessProcessGroup.create(host=master_address,
                                      port=master_port,
                                      rank=rank,
                                      world_size=world_size)
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl


class WorkerAdapter:
    """
    The class for vLLM's worker to inherit from.
    By defining an adapter, the code can work no matter what the
    underlying worker class is, which makes it compatible with both
    vLLM V0 and V1.
    NOTE: we define this class in a separate module, and the main module
    should pass the fully qualified name of the class as the
    `worker_adapter_cls` argument.
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        from vllm.distributed.parallel_state import get_world_group
        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )

    def update_weight(self, name, dtype, shape):
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(weight,
                                          src=0,
                                          stream=torch.cuda.current_stream())

        self.model_runner.model.load_weights(weights=[(name, weight)])

        del weight
Comment on lines +46 to +54

Collaborator: How are `update_weight` and `check_weights_changed` supposed to be used?

Member (Author): `update_weight` is called externally, by `collective_rpc("update_weight")`.
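Concretely, the external call pattern looks like this (a sketch following rlhf.py; train_model stands in for the trainer's Hugging Face model and is not part of this diff):

# For each trainer parameter: tell every worker to post a receiving
# broadcast, then broadcast the tensor from the trainer process (rank 0).
for name, p in train_model.named_parameters():
    handle = llm.collective_rpc.remote("update_weight",
                                       args=(name, p.dtype, p.shape))
    model_update_group.broadcast(p, src=0,
                                 stream=torch.cuda.current_stream())
    ray.get(handle)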


    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated
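On the driver side, the check can be asserted after zeroing the trainer's weights and re-broadcasting them (a hedged sketch; collective_rpc is assumed to return one result per worker):

# Hypothetical verification step, assuming the trainer broadcast all-zero
# weights via update_weight first:
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))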


class ColocateWorkerAdapter:
    """
    The class for vLLM's worker to inherit from in the colocate setting.
    By defining an adapter, the code can work no matter what the
    underlying worker class is, which makes it compatible with both
    vLLM V0 and V1.
    NOTE: we define this class in a separate module, and the main module
    should pass the fully qualified name of the class as the
    `worker_adapter_cls` argument.
    """

    def report_device_id(self) -> str:
        from vllm.platforms import current_platform
        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        handles = ipc_handles[self.device_uuid]
        device_id = self.device.index
        weights = []
        for name, handle in handles.items():
            func, args = handle
            list_args = list(args)
            # the key is to change the device id to the current device id,
            # in case the two processes have different CUDA_VISIBLE_DEVICES
            list_args[6] = device_id
            tensor = func(*list_args)
            weights.append((name, tensor))
        self.model_runner.model.load_weights(weights=weights)
        torch.cuda.synchronize()

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated
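The producer side of those IPC handles lives in the training actor. A sketch of what it must export, mirroring rlhf_colocate.py (the helper name and device_uuid argument are illustrative; the args[6] position comes from the consumer code above):

from torch.multiprocessing.reductions import reduce_tensor

def get_weight_ipc_handles(model, device_uuid):
    # reduce_tensor returns (rebuild_func, args) for a CUDA tensor; args[6]
    # holds the device index, which the consumer above overwrites with its own.
    handles = {name: reduce_tensor(p.detach())
               for name, p in model.named_parameters()}
    return {device_uuid: handles}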
4 changes: 4 additions & 0 deletions vllm/config.py
@@ -1366,6 +1366,7 @@ class ParallelConfig:
     # will be determined based on the platform.
     worker_cls: str = "auto"
     sd_worker_cls: str = "auto"
+    worker_adapter_cls: str = ""

     # world_size is TPxPP, it affects the number of workers we create.
     world_size: int = field(init=False)
@@ -1523,6 +1524,9 @@ def _verify_args(self) -> None:
             raise ValueError("Unable to use nsight profiling unless workers "
                              "run with Ray.")

+        assert isinstance(self.worker_adapter_cls, str), (
+            "worker_adapter_cls must be a string (qualified class name).")
+

 @dataclass
 class SchedulerConfig:
9 changes: 9 additions & 0 deletions vllm/engine/arg_utils.py
@@ -202,6 +202,7 @@ class EngineArgs:
     override_pooler_config: Optional[PoolerConfig] = None
     compilation_config: Optional[CompilationConfig] = None
     worker_cls: str = "auto"
+    worker_adapter_cls: str = ""

     kv_transfer_config: Optional[KVTransferConfig] = None

@@ -1015,6 +1016,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             type=str,
             default="auto",
             help='The worker class to use for distributed execution.')
+        parser.add_argument(
+            '--worker-adapter-cls',
+            type=str,
+            default="",
+            help='The worker adapter class on top of the worker class; '
+            'useful if you just want to add new functions to the worker '
+            'class without changing the existing functions.')
         parser.add_argument(
             "--generation-config",
             type=nullable_str,
@@ -1209,6 +1217,7 @@ def create_engine_config(self,
             ray_workers_use_nsight=self.ray_workers_use_nsight,
             distributed_executor_backend=self.distributed_executor_backend,
             worker_cls=self.worker_cls,
+            worker_adapter_cls=self.worker_adapter_cls,
         )

         max_model_len = model_config.max_model_len
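For end users the new knob is a plain string on both surfaces; a sketch (the examples in this PR pass it through the LLM constructor, and the flag mirrors the parser entry above):

from vllm import LLM

# The fully qualified class name is resolved on each worker process, so no
# class object ever has to be cloudpickled across process boundaries.
llm = LLM(
    model="facebook/opt-125m",
    worker_adapter_cls="rlhf_utils.WorkerAdapter",
)

The CLI equivalent is --worker-adapter-cls rlhf_utils.WorkerAdapter.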
24 changes: 24 additions & 0 deletions vllm/worker/worker_base.py
@@ -558,10 +558,34 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
             worker_class = resolve_obj_by_qualname(
                 self.vllm_config.parallel_config.worker_cls)
         else:
+            logger.warning(
+                "passing worker_cls as a class object is strongly deprecated,"
+                " as the serialization of class objects can be tricky and"
+                " error-prone. To be safe, please keep the class in a separate"
+                " module and pass the qualified name of the class as a string."
+            )
             assert isinstance(self.vllm_config.parallel_config.worker_cls,
                               bytes)
             worker_class = cloudpickle.loads(
                 self.vllm_config.parallel_config.worker_cls)
+        if self.vllm_config.parallel_config.worker_adapter_cls:
+            worker_adapter_class = resolve_obj_by_qualname(
+                self.vllm_config.parallel_config.worker_adapter_cls)
+            logger.info(
+                "Injecting %s into %s for extended collective_rpc call",
+                worker_adapter_class, worker_class)
+            if worker_adapter_class not in worker_class.__bases__:
+                # check for any conflicts between the worker and the
+                # worker adapter
+                for attr in dir(worker_adapter_class):
+                    if attr.startswith("__"):
+                        continue
+                    assert not hasattr(worker_class, attr), (
+                        f"Worker class {worker_class} already has an attribute"
+                        f" {attr}, which conflicts with the worker"
+                        f" adapter class {worker_adapter_class}.")
+                # dynamically inherit the worker adapter class
+                worker_class.__bases__ = worker_class.__bases__ + (
+                    worker_adapter_class, )
Collaborator: this is wild

Collaborator: honestly pretty cool that you can do this

Member (Author): any suggestion to make it better? The problem is that the worker class is dynamically determined based on V1 or V0.

Collaborator: Do you think we'd be able to simplify this once we move off of V0 completely?

Member (Author): I don't think so. This is not just for V0/V1, but also for different platforms: for example, if an RLHF framework wants to support both CUDA and ROCm, it can use a similar mixin class for both, without knowing what the underlying worker class is.

         with set_current_vllm_config(self.vllm_config):
             # To make vLLM config available during worker initialization
             self.worker = worker_class(**kwargs)
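To see why the conflict check and the __bases__ append are safe, here is a minimal standalone reproduction of the trick (assumes nothing from vLLM; note that appending to __bases__ only works when the class does not inherit directly from object, which holds for the worker classes here):

class WorkerBase:            # stands in for vLLM's V0/V1 worker base
    def work(self):
        return "work"


class Worker(WorkerBase):    # the class picked at runtime (V0, V1, platform)
    pass


class Adapter:               # user-provided mixin, known only by qualified name
    def extra(self):
        return "extra"


# Append rather than prepend, so existing worker attributes keep winning in
# the MRO; the conflict check above ensures the adapter never shadows any.
if Adapter not in Worker.__bases__:
    Worker.__bases__ = Worker.__bases__ + (Adapter, )

w = Worker()
assert w.work() == "work" and w.extra() == "extra"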