Skip to content

Commit 25e6907

Browse files
hmellorwuisawesome
authored andcommitted
Use @property and private field for data_parallel_rank_local (vllm-project#17053)
Signed-off-by: Harry Mellor <[email protected]>
1 parent d43a67e commit 25e6907

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

vllm/config.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,8 +1593,21 @@ class ParallelConfig:
15931593
the product of the tensor parallel size and data parallel size."""
15941594
data_parallel_rank: int = 0
15951595
"""Rank of the data parallel group."""
1596-
data_parallel_rank_local: Optional[int] = None
1597-
"""Local rank of the data parallel group, defaults to global rank."""
1596+
_data_parallel_rank_local: Optional[int] = field(default=None, init=False)
1597+
"""Private field to store the local rank of the data parallel group."""
1598+
1599+
@property
1600+
def data_parallel_rank_local(self) -> int:
1601+
"""Local rank of the data parallel group, defaults to global rank."""
1602+
if self._data_parallel_rank_local is None:
1603+
return self.data_parallel_rank
1604+
return self._data_parallel_rank_local
1605+
1606+
@data_parallel_rank_local.setter
1607+
def data_parallel_rank_local(self, value: int) -> None:
1608+
"""Set the local rank of the data parallel group."""
1609+
self._data_parallel_rank_local = value
1610+
15981611
data_parallel_master_ip: str = "127.0.0.1"
15991612
"""IP of the data parallel master."""
16001613
data_parallel_master_port: int = 29500

vllm/v1/engine/core_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,10 +439,10 @@ def _init_core_engines(
439439
) -> None:
440440

441441
# Default case - single core engine.
442-
dp_rank = vllm_config.parallel_config.data_parallel_rank
443-
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
444442
core_engine = new_core_engine(
445-
dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank)
443+
vllm_config.parallel_config.data_parallel_rank,
444+
vllm_config.parallel_config.data_parallel_rank_local,
445+
)
446446
core_engines.append(core_engine)
447447
self.core_engine = core_engine
448448

0 commit comments

Comments
 (0)