[core] LLM.collective_rpc interface and RLHF example #12084
Merged
Changes shown from 18 of 28 commits (all by youkaichao):

42e960a  try
2fbe131  try
3c771be  add rlhf example
3866cf6  fix args
b7d979a  fix rpc name
3991bfc  fix ray
e243189  fix ray?
a869e86  fix ray?
f59cf2a  fix ray?
7881683  fix ray?
19ec7b9  fix ray?
97f0de2  fix ray?
8cf7670  fix ray?
08e4ffb  gpu allocation control
29f74fa  fix imports
eab65b6  fix linter
fac3dcc  fix linter
68c2a06  add tests
6b6a171  add tests
e601693  fix tests
e9851ba  update examples
10645f7  update examples
97b67c4  lint
9f0c1c6  lint
972634f  fix tests
34cd83d  move examples
d757692  use docstring
98b417d  elif
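
This PR adds a collective_rpc interface to the LLM class, along with a worker_cls argument for plugging in a custom worker, and demonstrates both with a self-contained RLHF weight-synchronization example. A minimal sketch of the calling pattern, inferred from the example file below (EchoWorker and its echo method are hypothetical illustrations, not part of the PR):

```python
from vllm import LLM
from vllm.worker.worker import Worker


class EchoWorker(Worker):
    # any extra method defined on the worker subclass becomes callable
    # through the RPC interface (hypothetical example method)
    def echo(self, message):
        return f"rank {self.rank}: {message}"


# worker_cls plugs the subclass into every tensor-parallel worker
llm = LLM(model="facebook/opt-125m", worker_cls=EchoWorker)

# collective_rpc invokes the named method on every worker and returns
# the list of per-worker results
results = llm.collective_rpc("echo", args=("hello",))
```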
The new RLHF example file (152 lines added):

```python
# a simple demonstration of RLHF with vLLM.
import os

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams, configure_as_vllm_process
from vllm.utils import get_ip, get_open_port
from vllm.worker.worker import Worker


# recommended way to create data-plane communication
# between external (train) processes and vLLM workers.
def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup
    pg = StatelessProcessGroup.create(host=master_address,
                                      port=master_port,
                                      rank=rank,
                                      world_size=world_size)
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl


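# NOTE (added commentary, not in the original diff): StatelessProcessGroup
# builds a process group without touching torch.distributed's global state,
# so the training process's own distributed setup stays intact;
# PyNcclCommunicator then runs NCCL collectives (here, broadcast) over it.
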
# inference code, inheriting from Worker to provide custom functions
class MyWorker(Worker):

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        from vllm.distributed.parallel_state import get_world_group
        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )

    def update_weight(self, name, dtype, shape):
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(weight,
                                          src=0,
                                          stream=torch.cuda.current_stream())

        self.model_runner.model.load_weights(weights=[(name, weight)])

        del weight

    def get_weight_square_sum(self):
        sum_value = 0.0
        for name, p in self.model_runner.model.named_parameters():
            sum_value += p.square().sum().item()
        return sum_value


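# NOTE (added commentary): update_weight is the receiving side of a
# broadcast: the training process (rank 0) sends each parameter while every
# worker receives it into a fresh buffer and hands it to load_weights,
# one tensor at a time.
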
class MyLLM(LLM):

    def __init__(self, *args, **kwargs):
        # stop Ray from manipulating CUDA_VISIBLE_DEVICES at the top level:
        # the actor is created with num_gpus=0, so Ray would otherwise hide
        # every GPU from it
        del os.environ["CUDA_VISIBLE_DEVICES"]
        super().__init__(*args, **kwargs)


# the current process is the training process, and it takes 1 GPU.
# important: set some common environment variables the same way as vLLM workers.
configure_as_vllm_process()

train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")

# start Ray with 2 GPUs
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()

pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)

# the inference engine takes 2 GPUs.
# for simplicity, we define the MyWorker class in this self-contained script;
# normally, it should live in a separate file, with its qualified class name
# passed as the worker_cls parameter.
# `enforce_eager` is used here to reduce test time.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
    model="facebook/opt-125m",
    enforce_eager=True,
    worker_cls=MyWorker,
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
)

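# NOTE (added commentary): num_gpus=0 means the MyLLM actor itself reserves
# no GPU; with placement_group_capture_child_tasks=True, the worker
# processes vLLM spawns are captured into pg_inference and occupy its two
# GPU bundles.
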
# Generate texts from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0)

outputs_original = ray.get(llm.generate.remote(prompts, sampling_params))

master_address = get_ip()
master_port = get_open_port()

# set up the connection between the training process and the inference engine.
handle = llm.collective_rpc.remote("init_weight_update_group",
                                   args=(master_address, master_port, 1, 3))
model_update_group = stateless_init_process_group(master_address, master_port,
                                                  0, 3, torch.device("cuda:0"))
ray.get(handle)
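
# NOTE (added commentary): the weight-update group has world_size 3: the
# training process joins as rank 0, and the two vLLM workers join as ranks
# 1 and 2 (their internal ranks 0 and 1 plus rank_offset=1).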

# simulate training: modify the weights of the model.
for name, p in train_model.named_parameters():
    p.data.zero_()

# sync the weights from the training process to the inference engine.
for name, p in train_model.named_parameters():
    handle = llm.collective_rpc.remote("update_weight",
                                       args=(name, p.dtype, p.shape))
    model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
    ray.get(handle)

# check that the weights were updated.
weight_square_sum_values = ray.get(
    llm.collective_rpc.remote("get_weight_square_sum"))
for x in weight_square_sum_values:
    assert x == 0.0

# use the updated model to generate texts.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))

# they should differ from the original outputs.
for output_original, output_updated in zip(outputs_original, outputs_updated):
    generated_text_original = output_original.outputs[0].text
    generated_text_updated = output_updated.outputs[0].text
    assert generated_text_original != generated_text_updated
```
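
Running the example end to end requires at least three GPUs: the training process pins facebook/opt-125m to cuda:0, and Ray is started with CUDA_VISIBLE_DEVICES="1,2" so that each of the two tensor-parallel vLLM workers gets one of the remaining GPUs. The final assertions verify the round trip: after the zeroed weights are broadcast, every worker reports a weight square sum of 0.0, and greedy decoding (temperature=0) yields different text than before the update.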