Skip to content

Commit 064a94d

Browse files
committed
fix
Signed-off-by: Rui Qiao <[email protected]>
1 parent 32330da commit 064a94d

File tree

4 files changed

+25
-7
lines changed

4 files changed

+25
-7
lines changed

tests/v1/test_async_llm_dp.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,9 @@
1919
model="ibm-research/PowerMoE-3b",
2020
enforce_eager=True,
2121
disable_log_requests=True,
22-
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
22+
tensor_parallel_size=int(os.getenv("TP_SIZE", 2)),
2323
data_parallel_size=int(os.getenv("DP_SIZE", 2)),
24+
data_parallel_address="172.31.15.128",
2425
)
2526

2627
if not current_platform.supports_v1(engine_args.create_model_config()):
@@ -62,10 +63,10 @@ async def generate(engine: AsyncLLM,
6263
"output_kind",
6364
[
6465
RequestOutputKind.DELTA,
65-
RequestOutputKind.FINAL_ONLY,
66+
# RequestOutputKind.FINAL_ONLY,
6667
],
6768
)
68-
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
69+
@pytest.mark.parametrize("data_parallel_backend", ["ray"])
6970
@pytest.mark.asyncio
7071
async def test_load(output_kind: RequestOutputKind,
7172
data_parallel_backend: str):

vllm/config.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1847,6 +1847,10 @@ def __post_init__(self) -> None:
18471847
"please install Ray with `pip install "
18481848
"ray`.") from ray_utils.ray_import_err
18491849
backend = "ray"
1850+
elif self.data_parallel_backend == "ray":
1851+
logger.info("Using ray distributed inference because "
1852+
"data_parallel_backend is ray")
1853+
backend = "ray"
18501854
elif ray_found:
18511855
if self.placement_group:
18521856
backend = "ray"

vllm/v1/engine/core.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -904,10 +904,10 @@ def __init__(
904904
executor_fail_callback = lambda: input_queue.put_nowait(
905905
(EngineCoreRequestType.EXECUTOR_FAILED, b''))
906906

907-
input_addresses: list[str] = addresses["input_address"]
908-
output_addresses: list[str] = addresses["output_address"]
909-
coord_in_addr: Optional[str] = addresses.get("coord_in_addr")
910-
coord_out_addr: Optional[str] = addresses.get("coord_out_addr")
907+
input_addresses: list[str] = addresses["input_addresses"]
908+
output_addresses: list[str] = addresses["output_addresses"]
909+
coord_in_addr: Optional[str] = addresses.get("coord_in_address")
910+
coord_out_addr: Optional[str] = addresses.get("coord_out_address")
911911
self.client_count = len(output_addresses)
912912
self.coordinator = coord_out_addr is not None
913913

@@ -921,14 +921,17 @@ def __init__(
921921
# Counts forward-passes of the model so that we can synchronize
922922
# finished with DP peers every N steps.
923923
self.counter = 0
924+
self.current_wave = 0
924925

925926
# Initialize engine core and model.
926927
EngineCore.__init__(self, vllm_config, executor_class, log_stats,
927928
executor_fail_callback)
928929

930+
self.engine_index = engine_index
929931
self.step_fn = (self.step if self.batch_queue is None else
930932
self.step_with_batch_queue)
931933
self.engines_running = False
934+
self.last_counts = (0, 0)
932935

933936
# Background Threads and Queues for IO. These enable us to
934937
# overlap ZMQ socket IO with GPU since they release the GIL,

vllm/v1/engine/core_client.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1071,7 +1071,9 @@ def __init__(
10711071
executor_class: type[Executor],
10721072
log_stats: bool,
10731073
client_addresses: Optional[dict[str, str]] = None,
1074+
client_index: int = 0,
10741075
):
1076+
self.client_index = client_index
10751077
self.current_wave = 0
10761078
self.engines_running = False
10771079
self.reqs_in_flight: dict[str, CoreEngine] = {}
@@ -1085,6 +1087,14 @@ def __init__(
10851087
sync_ctx = zmq.Context(io_threads=2)
10861088
self.ctx = zmq.asyncio.Context(sync_ctx)
10871089

1090+
# List of [waiting, running] pair per engine.
1091+
self.lb_engines: list[list[int]] = []
1092+
self.first_req_sock_addr = get_open_zmq_inproc_path()
1093+
self.first_req_send_socket = make_zmq_socket(self.ctx,
1094+
self.first_req_sock_addr,
1095+
zmq.PAIR,
1096+
bind=True)
1097+
10881098
# This will ensure resources created so far are closed
10891099
# when the client is garbage collected, even if an
10901100
# exception is raised mid-construction.

0 commit comments

Comments
 (0)