[RLHF] use worker_extension_cls for compatibility with V0 and V1 #14185
Changes from 8 commits
@@ -0,0 +1,105 @@
```python
# SPDX-License-Identifier: Apache-2.0
import torch


def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):
    """
    vLLM provides `StatelessProcessGroup` to create a process group
    without considering the global process group in torch.distributed.
    It is recommended to create `StatelessProcessGroup`, and then initialize
    the data-plane communication (NCCL) between external (train processes)
    and vLLM workers.
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup
    pg = StatelessProcessGroup.create(host=master_address,
                                      port=master_port,
                                      rank=rank,
                                      world_size=world_size)
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl


class WorkerAdapter:
    """
    The class for vLLM's worker to inherit from.
    By defining an adapter, the code can work no matter what the
    underlying worker class is. This way, the code can be compatible
    with both vLLM V0 and V1.
    NOTE: we define this class in a separate module, and the main module
    should pass the fully qualified name of the class as the
    `worker_adapter_cls` argument.
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        from vllm.distributed.parallel_state import get_world_group
        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )

    def update_weight(self, name, dtype, shape):
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(weight,
                                          src=0,
                                          stream=torch.cuda.current_stream())

        self.model_runner.model.load_weights(weights=[(name, weight)])

        del weight

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated
```
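The training-process side that drives `WorkerAdapter` is not part of this file. Below is a rough sketch of it, loosely following vLLM's RLHF example; the `llm` handle, the module name `rlhf_utils`, and the process/GPU placement are assumptions, and in practice vLLM's example overlaps the two group initializations with a Ray actor rather than a thread.

```python
# Hypothetical training-process side (not in this PR's diff). Assumes `llm` is
# an already-constructed vLLM LLM whose workers were started with
# worker_adapter_cls="rlhf_utils.WorkerAdapter" and sit on their own GPUs.
from concurrent.futures import ThreadPoolExecutor

import torch

from rlhf_utils import stateless_init_process_group

master_address, master_port = "127.0.0.1", 29500
world_size = 2  # 1 training process + 1 vLLM worker (tensor_parallel_size=1)

# Workers join the weight-update group with rank_offset=1, so the training
# process takes rank 0. Both sides must set up the NCCL group concurrently,
# hence the background thread for the RPC.
pool = ThreadPoolExecutor(max_workers=1)
fut = pool.submit(llm.collective_rpc, "init_weight_update_group",
                  args=(master_address, master_port, 1, world_size))
model_update_group = stateless_init_process_group(
    master_address, master_port, 0, world_size, torch.device("cuda:0"))
fut.result()
```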
The new file continues with the colocate variant of the adapter:

```python
class ColocateWorkerAdapter:
    """
    The class for vLLM's worker to inherit from, in the colocate setting.
    By defining an adapter, the code can work no matter what the
    underlying worker class is. This way, the code can be compatible
    with both vLLM V0 and V1.
    NOTE: we define this class in a separate module, and the main module
    should pass the fully qualified name of the class as the
    `worker_adapter_cls` argument.
    """

    def report_device_id(self) -> str:
        from vllm.platforms import current_platform
        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        handles = ipc_handles[self.device_uuid]
        device_id = self.device.index
        weights = []
        for name, handle in handles.items():
            func, args = handle
            list_args = list(args)
            # the key is to change the device id to the current device id
            # in case two processes have different CUDA_VISIBLE_DEVICES
            list_args[6] = device_id
            tensor = func(*list_args)
            weights.append((name, tensor))
        self.model_runner.model.load_weights(weights=weights)
        torch.cuda.synchronize()

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
        """
        weights_updated = True
        for name, p in self.model_runner.model.named_parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p))
        return weights_updated
```
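The producer side of `update_weights_from_ipc_handles` is also outside this file. Below is a rough sketch for the simplest case, a single GPU shared by the trainer and one vLLM worker; the `llm` and `train_model` handles and the colocated layout are assumptions (vLLM's colocate example arranges this with Ray actors sharing GPU fractions).

```python
# Hypothetical producer side for ColocateWorkerAdapter (not in this PR's diff).
# Assumes the trainer shares its single GPU with the colocated vLLM worker.
from torch.multiprocessing.reductions import reduce_tensor

# Each worker reports the UUID of its GPU; with one colocated worker this is
# also the trainer's GPU.
device_uuid = llm.collective_rpc("report_device_id")[0]

# Map device UUID -> {parameter name: CUDA IPC handle}. reduce_tensor returns
# a (rebuild_function, args) pair; the worker re-invokes it in its own process
# after patching the device index (see list_args[6] above).
ipc_handles = {
    device_uuid: {
        name: reduce_tensor(p.detach())
        for name, p in train_model.named_parameters()
    }
}

llm.collective_rpc("update_weights_from_ipc_handles", args=(ipc_handles,))
```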
@@ -558,10 +558,34 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
```python
            worker_class = resolve_obj_by_qualname(
                self.vllm_config.parallel_config.worker_cls)
        else:
            logger.warning(
                "passing worker_cls as a class object is strongly deprecated,"
                " as the serialization of class objects can be tricky and"
                " error-prone. To be safe, please keep the class in a separate"
                " module and pass the qualified name of the class as a string."
            )
            assert isinstance(self.vllm_config.parallel_config.worker_cls,
                              bytes)
            worker_class = cloudpickle.loads(
                self.vllm_config.parallel_config.worker_cls)
        if self.vllm_config.parallel_config.worker_adapter_cls:
            worker_adapter_class = resolve_obj_by_qualname(
                self.vllm_config.parallel_config.worker_adapter_cls)
            logger.info(
                "Injecting %s into %s for extended collective_rpc call",
                worker_adapter_class, worker_class)
            if worker_adapter_class not in worker_class.__bases__:
                # check any conflicts between worker and worker_adapter
                for attr in dir(worker_adapter_class):
                    if attr.startswith("__"):
                        continue
                    assert not hasattr(worker_class, attr), (
                        f"Worker class {worker_class} already has an attribute"
                        f" {attr}, which conflicts with the worker"
                        f" adapter class {worker_adapter_class}.")
                # dynamically inherit the worker adapter class
                worker_class.__bases__ = worker_class.__bases__ + (
                    worker_adapter_class, )
```
Review comments on the dynamic `__bases__` injection above:

- "this is wild"
- "honestly pretty cool that you can do this"
- "any suggestion to make it better? the problem is, the worker class is dynamically determined based on V1 or V0."
- "Do you think we'd be able to simplify this once we move off of V0 completely?"
- "I don't think so, this is not just for V0/V1, but also for different platforms. For example, if an RLHF framework wants to support both CUDA and ROCm, they can use a similar mixin class for both, without knowing what the underlying worker class is."
The hunk then continues with the worker construction:

```python
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during worker initialization
            self.worker = worker_class(**kwargs)
```
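To make the trick discussed in the review comments above concrete, here is a self-contained toy version in plain Python. It is only an illustration; note that CPython rejects some `__bases__` assignments (for example when a class inherits directly from `object`), which vLLM's worker classes avoid by having their own base classes.

```python
# Toy illustration of injecting an adapter/mixin via __bases__ (no vLLM).
class WorkerBase:
    def execute_model(self):
        return "logits"


class Worker(WorkerBase):  # stands in for whichever V0/V1 worker is resolved
    pass


class RLHFAdapter:  # stands in for the user-provided worker_adapter_cls
    def update_weight(self):
        return "weights updated"


# Same pattern as the hunk above: append the adapter to the bases so its
# methods become reachable through Worker's MRO.
if RLHFAdapter not in Worker.__bases__:
    Worker.__bases__ = Worker.__bases__ + (RLHFAdapter, )

w = Worker()
print(w.execute_model())  # inherited from WorkerBase
print(w.update_weight())  # provided by the injected adapter
```

Because the adapter's methods are looked up through the class, they become callable on every `Worker` instance, which is what lets `collective_rpc` reach them by name without knowing the concrete worker class.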
A separate review thread asks how the adapter methods are invoked:

- "How are `update_weight` and `check_weights_changed` supposed to be used?"
- "`update_weight` is called externally, by `collective_rpc("update_weight")`."
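Continuing the hypothetical driver sketch from earlier (`llm`, `model_update_group`, and `pool` as defined there, and `train_model` standing in for the training framework's model), the per-parameter sync referenced in that answer could look roughly like this:

```python
# Hypothetical continuation of the driver sketch; not part of this PR's diff.
for name, p in train_model.named_parameters():
    # Ask every vLLM worker to post the receive side of the broadcast ...
    fut = pool.submit(llm.collective_rpc, "update_weight",
                      args=(name, p.dtype, p.shape))
    # ... and send the updated tensor from rank 0 (the training process).
    model_update_group.broadcast(p, src=0,
                                 stream=torch.cuda.current_stream())
    fut.result()

# In the example's test flow the training weights are zeroed beforehand, so
# this verifies that every worker actually loaded the new values.
assert all(llm.collective_rpc("check_weights_changed"))
```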