
[core] set up data parallel communication #13591


Merged
youkaichao merged 70 commits into vllm-project:main from manual_dp on Feb 22, 2025

Conversation

@youkaichao (Member) commented Feb 20, 2025

We need to explore data parallelism in many cases, e.g. for DeepSeek models and MoE models.

While the end-user interface is still to be designed, this PR first creates the necessary communication channel for data parallel and leaves the interface for future design.

  • In the future, as long as an external launcher sets up VLLM_DP_RANK, VLLM_DP_SIZE, VLLM_DP_MASTER_IP, VLLM_DP_MASTER_PORT, and CUDA_VISIBLE_DEVICES correctly, it will be compatible with this PR (see the launcher sketch after this list).
  • The main communication setup inside the worker now includes a DP group.
  • The engine process also has a separate DP group to communicate across DP instances.
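
As an illustration of the launcher contract above, here is a minimal single-node sketch of an external launcher that sets these variables per DP rank. The script name, master IP/port values, and GPUs-per-rank mapping are illustrative assumptions, not part of this PR:

    import os
    import subprocess

    DP_SIZE = 2
    GPUS_PER_DP_RANK = 2  # e.g. TP=2 inside each DP rank (assumption)

    procs = []
    for dp_rank in range(DP_SIZE):
        env = os.environ.copy()
        env["VLLM_DP_RANK"] = str(dp_rank)
        env["VLLM_DP_SIZE"] = str(DP_SIZE)
        env["VLLM_DP_MASTER_IP"] = "127.0.0.1"  # single-node example
        env["VLLM_DP_MASTER_PORT"] = "29501"    # arbitrary free port
        # give each DP rank its own slice of GPUs
        first = dp_rank * GPUS_PER_DP_RANK
        env["CUDA_VISIBLE_DEVICES"] = ",".join(
            str(first + i) for i in range(GPUS_PER_DP_RANK))
        # my_vllm_script.py is a hypothetical per-rank entry point
        procs.append(subprocess.Popen(["python", "my_vllm_script.py"], env=env))

    for p in procs:
        p.wait()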

Example command to run with data parallel: torchrun --nproc-per-node=2 examples/offline_inference/data_parallel.py

Note: this PR only sets up the communication channel; it is not used in the model forward pass yet. To get the benefit of data parallel, especially in combination with expert parallel, we need to:

  • Implement execute_dummy_batch in the engines when should_execute_dummy_batch == True (see the sketch after this list for why idle ranks must stay in lockstep).
  • Synchronize use_cuda_graph in the model runner across DP groups. This is not strictly necessary today, but if any collective operation behaves differently with and without CUDA graphs, the sync becomes necessary.
  • Change the MoE loading logic to shard experts across the world size instead of the TP size.
  • Add all-to-all communication before and after the MoE computation to gather selection logits across DP ranks.
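
To make the dummy-batch item above concrete, here is a hedged sketch of why idle DP ranks still have to run a forward pass. The helper and its arguments are illustrative, not the code added in this PR, and it assumes an already-initialized DP process group:

    import torch
    import torch.distributed as dist

    def maybe_step(dp_group, has_local_work: bool, run_batch, run_dummy_batch):
        # Does any DP rank still have real work this step?
        flag = torch.tensor([int(has_local_work)], device="cuda")
        dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=dp_group)
        if flag.item() == 0:
            return  # every rank is idle; nothing to do
        if has_local_work:
            run_batch()
        else:
            # MoE all-to-all (and other collectives) need every DP rank to
            # participate, so an idle rank runs a dummy batch instead of
            # skipping the step and leaving the busy ranks hanging.
            run_dummy_batch()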

NOTE: I think PP is currently not really compatible with DP; the combination is quite complicated to reason about right now.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to be added to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

backend,
group_name="dp")

logger.info(
Member Author


Example of the rank assignment for DP=2 x TP=2:

rank 0 in world size 4 is assigned as DP rank 0, PP rank 0, TP rank 0
rank 1 in world size 4 is assigned as DP rank 0, PP rank 0, TP rank 1
rank 2 in world size 4 is assigned as DP rank 1, PP rank 0, TP rank 0
rank 3 in world size 4 is assigned as DP rank 1, PP rank 0, TP rank 1
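
A small sketch of the decomposition implied by this mapping (DP outermost, then PP, then TP innermost); the helper is illustrative, not the code in this PR:

    def decompose_rank(rank: int, dp_size: int, pp_size: int, tp_size: int):
        assert rank < dp_size * pp_size * tp_size
        dp_rank, rem = divmod(rank, pp_size * tp_size)
        pp_rank, tp_rank = divmod(rem, tp_size)
        return dp_rank, pp_rank, tp_rank

    for r in range(4):
        print(r, decompose_rank(r, dp_size=2, pp_size=1, tp_size=2))
    # 0 (0, 0, 0)
    # 1 (0, 0, 1)
    # 2 (1, 0, 0)
    # 3 (1, 0, 1)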

@tlrmchlsmth (Collaborator) left a comment


JFYI: I ran into an issue with the master port already being in use (see comment in config.py)

self.data_parallel_size = envs.VLLM_DP_SIZE
self.data_parallel_rank = envs.VLLM_DP_RANK
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
Collaborator


Note that I'm hitting issues like:

RuntimeError: The server socket has failed to listen on any local network address. port: 29500, useIpv6: 0, code: -98, name: EADDRINUSE, message: address already in use

This is true even if I change the master port with torchrun --master-port .... Currently hacking around it by changing this to self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT + 1

Member Author


That's strange. I also hit it once, but then it disappeared.

Member Author


It seems this disappeared when I removed torchrun in af53b4b.

Comment on lines +1344 to +1345
answer = self.data_parallel_master_port
self.data_parallel_master_port += 1
Collaborator


What if the port is already being used by other services?

Member Author


Then it will error.

We can document that we will use more than one port, starting from the specified one; that assumption should usually be fine.

NOTE: even if we only use the specified port, there is still a chance that some other service grabs it before we do. That is unavoidable when running multiple services on the same host, but for cloud deployments, where each service runs in a separate container, it should be fine.
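
A minimal sketch of the scheme described here, i.e. handing out consecutive ports starting from the configured base. The class and method names are illustrative; the PR itself simply increments self.data_parallel_master_port as shown in the diff above:

    class DPPortAllocator:
        # Hands out consecutive ports starting from a configured base port.
        # If one of those ports is already taken by an unrelated service,
        # binding will simply fail; that is the trade-off discussed here.
        def __init__(self, base_port: int):
            self.next_port = base_port

        def get_port(self) -> int:
            port = self.next_port
            self.next_port += 1
            return port

    allocator = DPPortAllocator(base_port=29500)
    dp_master_port = allocator.get_port()  # 29500
    another_port = allocator.get_port()    # 29501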

Collaborator


Alternatively, we could check whether the port is in use via a socket, and keep searching for the next available port.

Member Author


This is not feasible, because non-zero ranks connect directly to the specified port and cannot tell whether they are talking to the master rank or to some other service; they also need to wait for a while in case the master rank has not started yet.

Member Author


I added the code in 267cd82; at least vLLM's internal port usage will not conflict with the DP master ports.

Comment on lines 186 to 188
if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False
# TODO: execute a dummy batch to sync across ranks
Collaborator


Looks like this is not the right place for this logic? I feel it should be in the EngineCore's busy loop.

Member Author


I'm not familiar with the engine part; can you show me where I should put it?


Member Author


A bitter lesson: we need to place this logic at the top level, which is the LLMEngine level for offline inference.

We cannot put it in the EngineCore's busy loop; otherwise, the LLMEngine will exit directly without checking the status of the other DP ranks.
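
A conceptual sketch of that top-level placement (the names here are illustrative, not vLLM's actual API): the offline engine keeps stepping until every DP rank is finished, and a rank that runs out of requests early executes dummy batches instead of exiting:

    def offline_dp_loop(engine, dp_any_unfinished):
        # `engine` stands in for the offline LLMEngine; `dp_any_unfinished`
        # is assumed to reach consensus across DP ranks, e.g. via an
        # all-reduce over the DP group.
        while True:
            local_unfinished = engine.has_unfinished_requests()
            if not dp_any_unfinished(local_unfinished):
                break  # every DP rank is done
            if local_unfinished:
                engine.step()
            else:
                engine.execute_dummy_batch()  # keep collectives aligned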

@youkaichao (Member Author)

> @youkaichao instead of calling dummy forward as a utility method, could we instead modify the step() method in core.py like this, and have the model runner's execute_model call _dummy_run if it gets None as the scheduler output?

    def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output."""

        if not self.scheduler.has_unfinished_requests():
            self.model_executor.execute_model(None)
            return EngineCoreOutputs(
                outputs=[], scheduler_stats=self.scheduler.make_stats())

@njhill I tried that approach as well, but didn't succeed. It needs more changes, e.g. we would need to change the semantics of execute_model to define what None as input means, and it breaks several other pieces of code. I gave up because I'm not familiar with that part of the code, but feel free to give it a try after this PR.

@tlrmchlsmth (Collaborator) left a comment


LGTM

@youkaichao youkaichao enabled auto-merge (squash) February 22, 2025 06:44
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 22, 2025
@youkaichao (Member Author)

Failed tests are due to an HF timeout; merging.

@youkaichao youkaichao disabled auto-merge February 22, 2025 11:28
@youkaichao youkaichao merged commit 3e472d8 into vllm-project:main Feb 22, 2025
67 of 72 checks passed
@youkaichao youkaichao deleted the manual_dp branch February 22, 2025 11:29
@youkaichao youkaichao mentioned this pull request Feb 22, 2025
if len(prompts) == 0:
    # if any rank has no prompts to process,
    # we need to set a placeholder prompt
    prompts = ["Placeholder"]
Member


I know this is just an example but in practice I guess you'd want to set max_tokens to 1 for any placeholder prompts.
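
A hedged illustration of that suggestion, extending the example snippet above (SamplingParams and max_tokens are standard vLLM; the particular parameter values and wiring are assumptions):

    from vllm import SamplingParams

    prompts = []  # e.g. this DP rank received no prompts
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    if len(prompts) == 0:
        # This rank still has to join the collective communication, so give
        # it a throwaway prompt and keep the wasted work minimal by
        # generating a single token.
        prompts = ["Placeholder"]
        sampling_params = SamplingParams(max_tokens=1)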

@lewisword

lewisword commented Mar 14, 2025

May I ask whether this feature can be used for online serving? From the example in examples/offline_inference/data_parallel.py, I see it uses an offline multi-process invocation approach. @youkaichao

@njhill
Member

njhill commented Mar 17, 2025

@lewisword not yet, but it will be coming via #13923.

@QiuMike

QiuMike commented May 9, 2025

@youkaichao I ran the offline example on H20 with two nodes; each node has 8 cards.

export VLLM_DP_MASTER_IP=10.13.3.163
export GLOO_SOCKET_IFNAME=eth0
export TP_SOCKET_IFNAME=eth0

python3 examples/offline_inference/data_parallel.py --node-size 2 --node-rank 0 --master-addr 10.13.3.163 --model /home/xxxxx/DeepSeek-R1 --master-port 13345 --dp-size 2 --tp-size 8

python3 examples/offline_inference/data_parallel.py --node-size 2 --model /home/xxxxx/DeepSeek-R1/ --node-rank 1 --master-addr 10.13.3.163 --master-port 13345 --dp-size 2 --tp-size 8

DP rank 0, Prompt: 'Hello, my name is', Generated text: ' Danielle. I’m a new Master’s student in the Sustainability and Energy'
DP rank 0, Prompt: 'The president of the United States is', Generated text: ' the head of state and head of government of the United States, indirectly elected to'
DP rank 0, Prompt: 'The capital of France is', Generated text: ' Paris, and the three major cities are Paris, Lyon, and Marseille. France'
DP rank 0, Prompt: 'The future of AI is', Generated text: ' a topic that has been discussed and debated by experts, researchers, and enthusiasts alike'
DP rank 0, Prompt: 'Hello, my name is', Generated text: " Mr. Sato.\nLet's learn how to identify themes in literature.\nFirst of"
(EngineCore_0 pid=28846) INFO 05-08 13:09:23 [core.py:372] EngineCore exiting with signum 15
(EngineCore_0 pid=28846) Process EngineCore_0:
(EngineCore_0 pid=28846) Traceback (most recent call last):
(EngineCore_0 pid=28846) File "/home/admin/michael/vllm/vllm/v1/engine/core.py", line 394, in run_engine_core
(EngineCore_0 pid=28846) engine_core.run_busy_loop()
(EngineCore_0 pid=28846) File "/home/admin/michael/vllm/vllm/v1/engine/core.py", line 687, in run_busy_loop
(EngineCore_0 pid=28846) self.execute_dummy_batch()
(EngineCore_0 pid=28846) File "/home/admin/michael/vllm/vllm/v1/engine/core.py", line 281, in execute_dummy_batch
(EngineCore_0 pid=28846) self.model_executor.collective_rpc("execute_dummy_batch")
(EngineCore_0 pid=28846) File "/home/admin/michael/vllm/vllm/v1/executor/multiproc_executor.py", line 215, in collective_rpc
(EngineCore_0 pid=28846) result = get_response(w, dequeue_timeout)
(EngineCore_0 pid=28846) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=28846) File "/home/admin/michael/vllm/vllm/v1/executor/multiproc_executor.py", line 198, in get_response
(EngineCore_0 pid=28846) status, result = w.worker_response_mq.dequeue(
(EngineCore_0 pid=28846) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=28846) File "/home/admin/michael/vllm/vllm/distributed/device_communicators/shm_broadcast.py", line 479, in dequeue
(EngineCore_0 pid=28846) with self.acquire_read(timeout, cancel) as buf:
(EngineCore_0 pid=28846) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=28846) File "/usr/lib/python3.12/contextlib.py", line 137, in __enter__
(EngineCore_0 pid=28846) return next(self.gen)
(EngineCore_0 pid=28846) ^^^^^^^^^^^^^^
(EngineCore_0 pid=28846) File "/home/admin/michael/vllm/vllm/distributed/device_communicators/shm_broadcast.py", line 425, in acquire_read
(EngineCore_0 pid=28846) sched_yield()
(EngineCore_0 pid=28846) File "/home/admin/michael/vllm/vllm/distributed/device_communicators/shm_broadcast.py", line 41, in sched_yield
(EngineCore_0 pid=28846) os.sched_yield()
(EngineCore_0 pid=28846) File "/home/admin/michael/vllm/vllm/v1/engine/core.py", line 376, in signal_handler
(EngineCore_0 pid=28846) raise SystemExit()
(EngineCore_0 pid=28846) SystemExit
(EngineCore_0 pid=28846)
(EngineCore_0 pid=28846) During handling of the above exception, another exception occurred:
(EngineCore_0 pid=28846)
(EngineCore_0 pid=28846) Traceback (most recent call last):
(EngineCore_0 pid=28846) File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
