
Commit c55cbb8

[V1] Support DP with Ray

Signed-off-by: Rui Qiao <[email protected]>
1 parent 34c5eb9 commit c55cbb8

File tree

8 files changed: +536, -15 lines


tests/v1/test_async_llm_dp.py

Lines changed: 12 additions & 4 deletions
@@ -19,8 +19,9 @@
     model="ibm-research/PowerMoE-3b",
     enforce_eager=True,
     disable_log_requests=True,
-    tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
+    tensor_parallel_size=int(os.getenv("TP_SIZE", 2)),
     data_parallel_size=int(os.getenv("DP_SIZE", 2)),
+    data_parallel_address="172.31.15.128",
 )
 
 if not current_platform.supports_v1(engine_args.create_model_config()):
@@ -59,14 +60,22 @@ async def generate(engine: AsyncLLM,
 
 
 @pytest.mark.parametrize(
-    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
+    "output_kind",
+    [
+        RequestOutputKind.DELTA,
+        # RequestOutputKind.FINAL_ONLY,
+    ],
+)
+@pytest.mark.parametrize("data_parallel_backend", ["ray"])
 @pytest.mark.asyncio
-async def test_load(output_kind: RequestOutputKind):
+async def test_load(output_kind: RequestOutputKind,
+                    data_parallel_backend: str):
 
     with ExitStack() as after:
 
         prompt = "This is a test of data parallel"
 
+        engine_args.data_parallel_backend = data_parallel_backend
         engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
@@ -82,7 +91,6 @@ async def test_load(output_kind: RequestOutputKind):
                 asyncio.create_task(
                     generate(engine, request_id, prompt, output_kind,
                              NUM_EXPECTED_TOKENS)))
-
         # Confirm that we got all the EXPECTED tokens from the requests.
         done, pending = await asyncio.wait(tasks,
                                            return_when=asyncio.FIRST_EXCEPTION)

vllm/config.py

Lines changed: 6 additions & 0 deletions
@@ -1693,6 +1693,8 @@ class ParallelConfig:
     """Port for data parallel messaging."""
     data_parallel_master_port: int = 29500
     """Port of the data parallel master."""
+    data_parallel_backend: str = "mp"
+    """Backend to use for data parallel, either "mp" or "ray"."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
     max_parallel_loading_workers: Optional[int] = None
@@ -1856,6 +1858,10 @@ def __post_init__(self) -> None:
                                      "please install Ray with `pip install "
                                      "ray`.") from ray_utils.ray_import_err
                 backend = "ray"
+            elif self.data_parallel_backend == "ray":
+                logger.info("Using ray distributed inference because "
+                            "data_parallel_backend is ray")
+                backend = "ray"
             elif ray_found:
                 if self.placement_group:
                     backend = "ray"

vllm/engine/arg_utils.py

Lines changed: 9 additions & 0 deletions
@@ -290,6 +290,7 @@ class EngineArgs:
     data_parallel_size_local: Optional[int] = None
     data_parallel_address: Optional[str] = None
     data_parallel_rpc_port: Optional[int] = None
+    data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     max_parallel_loading_workers: Optional[
         int] = ParallelConfig.max_parallel_loading_workers
@@ -618,6 +619,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                     type=int,
                                     help='Port for data parallel RPC '
                                     'communication.')
+        parallel_group.add_argument('--data-parallel-backend',
+                                    '-dpb',
+                                    type=str,
+                                    help='Backend for data parallel, either '
+                                    '"mp" or "ray".')
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])
@@ -1058,13 +1064,16 @@ def create_engine_config(
             self.data_parallel_rpc_port
             is not None) else ParallelConfig.data_parallel_rpc_port
 
+        data_parallel_backend = self.data_parallel_backend
+
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
             data_parallel_size=self.data_parallel_size,
            data_parallel_size_local=data_parallel_size_local,
             data_parallel_master_ip=data_parallel_address,
             data_parallel_rpc_port=data_parallel_rpc_port,
+            data_parallel_backend=data_parallel_backend,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
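
With this plumbing in place, the backend can be selected either with the new CLI flag (--data-parallel-backend ray, or -dpb ray) or programmatically on the engine args, which is what the updated test does. A minimal sketch of the programmatic path; the model name and parallel sizes are placeholders mirroring tests/v1/test_async_llm_dp.py above.

import os

from vllm.engine.arg_utils import AsyncEngineArgs

engine_args = AsyncEngineArgs(
    model="ibm-research/PowerMoE-3b",  # placeholder model
    enforce_eager=True,
    tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
    data_parallel_size=int(os.getenv("DP_SIZE", 2)),
    # New field from this commit; equivalent to --data-parallel-backend ray.
    data_parallel_backend="ray",
)
# AsyncLLM.from_engine_args(engine_args) then selects the Ray-based engine
# client (see the vllm/v1/engine/async_llm.py change below).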

vllm/entrypoints/cli/serve.py

Lines changed: 30 additions & 2 deletions
@@ -28,9 +28,9 @@
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
 from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
-                           get_engine_client_zmq_addr,
+                           CoreEngineActorManager, get_engine_client_zmq_addr,
                            wait_for_completion_or_failure,
-                           wait_for_engine_startup)
+                           wait_for_engine_startup, wait_for_ray_engine_actors)
 
 logger = init_logger(__name__)
 
@@ -221,6 +221,34 @@ def run_multi_api_server(args: argparse.Namespace):
         logger.info("Started DP Coordinator process (PID: %d)",
                     coordinator.proc.pid)
 
+    if parallel_config.data_parallel_backend == "ray":
+        logger.info("Starting ray-based data parallel backend")
+
+        engine_actor_manager = CoreEngineActorManager(
+            local_engine_count=local_engine_count,
+            start_index=args.data_parallel_start_rank,
+            local_start_index=0,
+            vllm_config=vllm_config,
+            addresses=addresses,
+            executor_class=Executor.get_class(vllm_config),
+            log_stats=not engine_args.disable_log_stats,
+        )
+        # Start API servers using the manager
+        api_server_manager = APIServerProcessManager(
+            target_server_fn=run_api_server_worker,
+            listen_address=listen_address,
+            sock=sock,
+            args=args,
+            num_servers=num_api_servers,
+            input_addresses=input_addresses,
+            output_addresses=output_addresses,
+            stats_update_address=stats_update_address)
+
+        wait_for_ray_engine_actors(api_server_manager=api_server_manager,
+                                   engine_actor_manager=engine_actor_manager,
+                                   coordinator=coordinator)
+        return
+
     handshake_address = get_engine_client_zmq_addr(
         local_only, host, parallel_config.data_parallel_rpc_port)
 

vllm/v1/engine/async_llm.py

Lines changed: 11 additions & 4 deletions
@@ -25,7 +25,8 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, cdiv
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
+from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient,
+                                        RayDPClient)
 from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
 from vllm.v1.engine.output_processor import (OutputProcessor,
                                              RequestOutputCollector)
@@ -114,9 +115,15 @@ def __init__(
             log_stats=self.log_stats)
 
         # EngineCore (starts the engine in background process).
-        core_client_class = AsyncMPClient if (
-            vllm_config.parallel_config.data_parallel_size
-            == 1) else DPAsyncMPClient
+        core_client_class: Union[type[RayDPClient], type[DPAsyncMPClient],
+                                 type[AsyncMPClient]]
+        if vllm_config.parallel_config.data_parallel_size > 1:
+            if vllm_config.parallel_config.data_parallel_backend == "ray":
+                core_client_class = RayDPClient
+            else:
+                core_client_class = DPAsyncMPClient
+        else:
+            core_client_class = AsyncMPClient
 
         self.engine_core = core_client_class(
             vllm_config=vllm_config,

vllm/v1/engine/core.py

Lines changed: 105 additions & 0 deletions
@@ -866,3 +866,108 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
 
         return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                 local_unfinished)
+
+
+class DPEngineCoreActor(DPEngineCoreProc):
+    """
+    Ray Actor for running EngineCore in a data parallel context
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        on_head_node: bool,
+        addresses,
+        executor_class: type[Executor],
+        log_stats: bool,
+        engine_index: int = 0,
+        dp_rank: int = 0,
+        local_dp_rank: int = 0,
+    ):
+        # TODO(rui): improve shutdown handling
+
+        # Ensure we can serialize transformer config after spawning
+        maybe_register_config_serialize_by_value()
+
+        parallel_config: ParallelConfig = vllm_config.parallel_config
+        assert parallel_config.data_parallel_size > 1 or dp_rank > 0
+        # Set data parallel rank for this engine process.
+        parallel_config.data_parallel_rank = dp_rank
+        parallel_config.data_parallel_rank_local = local_dp_rank
+
+        input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
+
+        executor_fail_callback = lambda: input_queue.put_nowait(
+            (EngineCoreRequestType.EXECUTOR_FAILED, b''))
+
+        input_addresses: list[str] = addresses["input_addresses"]
+        output_addresses: list[str] = addresses["output_addresses"]
+        coord_in_addr: Optional[str] = addresses.get("coord_in_address")
+        coord_out_addr: Optional[str] = addresses.get("coord_out_address")
+        self.client_count = len(output_addresses)
+        self.coordinator = coord_out_addr is not None
+
+        # Ray sets CUDA_VISIBLE_DEVICES to empty string,
+        # we clean this up to be able to properly initialize
+        # data parallel groups.
+        del os.environ['CUDA_VISIBLE_DEVICES']
+        # Set up data parallel environment.
+        self._init_data_parallel(vllm_config)
+
+        # Counts forward-passes of the model so that we can synchronize
+        # finished with DP peers every N steps.
+        self.counter = 0
+        self.current_wave = 0
+
+        # Initialize engine core and model.
+        EngineCore.__init__(self, vllm_config, executor_class, log_stats,
+                            executor_fail_callback)
+
+        self.engine_index = engine_index
+        self.step_fn = (self.step if self.batch_queue is None else
+                        self.step_with_batch_queue)
+        self.engines_running = False
+        self.last_counts = (0, 0)
+
+        # Background Threads and Queues for IO. These enable us to
+        # overlap ZMQ socket IO with GPU since they release the GIL,
+        # and to overlap some serialization/deserialization with the
+        # model forward pass.
+        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
+        self.input_queue = input_queue
+        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
+                                              bytes]]()
+        identity = engine_index.to_bytes(length=2, byteorder="little")
+        threading.Thread(target=self.process_input_sockets,
+                         args=(input_addresses, coord_in_addr, identity),
+                         daemon=True).start()
+        self.output_thread = threading.Thread(
+            target=self.process_output_sockets,
+            args=(output_addresses, coord_out_addr, engine_index),
+            daemon=True)
+        self.output_thread.start()
+
+    def wait_for_init(self):
+        """
+        Wait until the engine core is initialized.
+
+        This is just an empty method. When ray.get() on this method
+        (or any other method of the actor) returns, it is guaranteed
+        that actor creation (i.e., __init__) is complete.
+        """
+        pass
+
+    def run(self):
+        """
+        Run the engine core busy loop.
+        """
+        try:
+            self.run_busy_loop()
+        except SystemExit:
+            logger.debug("EngineCore exiting.")
+            raise
+        except Exception:
+            logger.exception("EngineCore encountered a fatal error.")
+            raise
+        finally:
+            self.shutdown()
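
In this commit the new actor is created and supervised by CoreEngineActorManager (imported in the vllm/entrypoints/cli/serve.py change above). The snippet below is only a rough, hypothetical illustration of the lifecycle that wait_for_init() and run() imply, assuming vllm_config, addresses, and executor_cls have been prepared the way run_multi_api_server prepares them.

import ray

# Hypothetical standalone driver for one engine actor; the real code path goes
# through CoreEngineActorManager, so treat this as a sketch only.
actor = ray.remote(num_gpus=1)(DPEngineCoreActor).remote(
    vllm_config=vllm_config,      # assumed prepared by the caller
    on_head_node=True,
    addresses=addresses,          # dict of input/output/coordinator ZMQ addresses
    executor_class=executor_cls,  # e.g. Executor.get_class(vllm_config)
    log_stats=True,
    dp_rank=0,
    local_dp_rank=0,
)

# ray.get() on any actor method returns only after __init__ has finished,
# which is exactly what the no-op wait_for_init() is for.
ray.get(actor.wait_for_init.remote())

# Start the busy loop; the returned ObjectRef resolves when the engine exits.
run_ref = actor.run.remote()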
