Skip to content

Commit caed77f

Browse files
Refactor TorchElasticEnvironment.detect to use torch.distributed.is_torchelastic_launched (#12376)
* Refactor TorchElasticEnvironment.detect to use native utility from torch.distributed * fix version and tests * fix version * Update tests/accelerators/test_accelerator_connector.py Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent fe940e1 commit caed77f

File tree

6 files changed

+49
-4
lines changed

6 files changed

+49
-4
lines changed

pytorch_lightning/plugins/environments/torchelastic_environment.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
import logging
1616
import os
1717

18+
import torch.distributed
19+
1820
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
21+
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_9_1
1922
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
2023

2124
log = logging.getLogger(__name__)
@@ -58,6 +61,8 @@ def main_port(self) -> int:
5861
@staticmethod
5962
def detect() -> bool:
6063
"""Returns ``True`` if the current process was launched using the torchelastic command."""
64+
if _TORCH_GREATER_EQUAL_1_9_1:
65+
return torch.distributed.is_torchelastic_launched()
6166
required_env_vars = {"RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"}
6267
return required_env_vars.issubset(os.environ.keys())
6368

pytorch_lightning/utilities/imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
9292
_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
9393
_TORCH_GREATER_EQUAL_1_8_1 = _compare_version("torch", operator.ge, "1.8.1")
9494
_TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
95+
_TORCH_GREATER_EQUAL_1_9_1 = _compare_version("torch", operator.ge, "1.9.1")
9596
_TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0")
9697
_TORCH_GREATER_EQUAL_1_11 = _compare_version("torch", operator.ge, "1.11.0")
9798

tests/accelerators/test_accelerator_connector.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def test_accelerator_choice_ddp2_slurm(*_):
147147
"RANK": "1",
148148
"LOCAL_RANK": "1",
149149
"GROUP_RANK": "0",
150+
"TORCHELASTIC_RUN_ID": "1", # present for torch >= 1.9.1
150151
},
151152
)
152153
@mock.patch("torch.cuda.set_device")
@@ -172,6 +173,7 @@ def test_accelerator_choice_ddp_te(*_):
172173
"RANK": "1",
173174
"LOCAL_RANK": "1",
174175
"GROUP_RANK": "0",
176+
"TORCHELASTIC_RUN_ID": "1",
175177
},
176178
)
177179
@mock.patch("torch.cuda.set_device")
@@ -189,7 +191,15 @@ def test_accelerator_choice_ddp2_te(*_):
189191

190192

191193
@mock.patch.dict(
192-
os.environ, {"WORLD_SIZE": "2", "LOCAL_WORLD_SIZE": "2", "RANK": "1", "LOCAL_RANK": "1", "GROUP_RANK": "0"}
194+
os.environ,
195+
{
196+
"WORLD_SIZE": "2",
197+
"LOCAL_WORLD_SIZE": "2",
198+
"RANK": "1",
199+
"LOCAL_RANK": "1",
200+
"GROUP_RANK": "0",
201+
"TORCHELASTIC_RUN_ID": "1",
202+
},
193203
)
194204
@mock.patch("torch.cuda.device_count", return_value=0)
195205
@mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
@@ -786,6 +796,7 @@ def test_strategy_choice_ddp2_slurm(
786796
"RANK": "1",
787797
"LOCAL_RANK": "1",
788798
"GROUP_RANK": "0",
799+
"TORCHELASTIC_RUN_ID": "1",
789800
},
790801
)
791802
@mock.patch("torch.cuda.set_device")
@@ -810,6 +821,7 @@ def test_strategy_choice_ddp_te(*_):
810821
"RANK": "1",
811822
"LOCAL_RANK": "1",
812823
"GROUP_RANK": "0",
824+
"TORCHELASTIC_RUN_ID": "1",
813825
},
814826
)
815827
@mock.patch("torch.cuda.set_device")
@@ -826,7 +838,15 @@ def test_strategy_choice_ddp2_te(*_):
826838

827839

828840
@mock.patch.dict(
829-
os.environ, {"WORLD_SIZE": "2", "LOCAL_WORLD_SIZE": "2", "RANK": "1", "LOCAL_RANK": "1", "GROUP_RANK": "0"}
841+
os.environ,
842+
{
843+
"WORLD_SIZE": "2",
844+
"LOCAL_WORLD_SIZE": "2",
845+
"RANK": "1",
846+
"LOCAL_RANK": "1",
847+
"GROUP_RANK": "0",
848+
"TORCHELASTIC_RUN_ID": "1",
849+
},
830850
)
831851
@mock.patch("torch.cuda.device_count", return_value=0)
832852
@mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)

tests/models/test_gpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun
205205
"RANK": "3",
206206
"WORLD_SIZE": "4",
207207
"LOCAL_WORLD_SIZE": "2",
208+
"TORCHELASTIC_RUN_ID": "1",
208209
},
209210
)
210211
@mock.patch("torch.cuda.device_count", return_value=1)

tests/plugins/environments/test_torchelastic_environment.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919

2020
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
21+
from tests.helpers.runif import RunIf
2122

2223

2324
@mock.patch.dict(os.environ, {})
@@ -70,8 +71,9 @@ def test_attributes_from_environment_variables(caplog):
7071
assert "setting world size is not allowed" in caplog.text
7172

7273

73-
def test_detect():
74-
"""Test the detection of a torchelastic environment configuration."""
74+
@RunIf(max_torch="1.9.0")
75+
def test_detect_before_1_9_1():
76+
"""Test the detection of a torchelastic environment configuration before 1.9.1."""
7577
with mock.patch.dict(os.environ, {}):
7678
assert not TorchElasticEnvironment.detect()
7779

@@ -85,3 +87,18 @@ def test_detect():
8587
},
8688
):
8789
assert TorchElasticEnvironment.detect()
90+
91+
92+
@RunIf(min_torch="1.9.1")
93+
def test_detect_after_1_9_1():
94+
"""Test the detection of a torchelastic environment configuration after 1.9.1."""
95+
with mock.patch.dict(os.environ, {}):
96+
assert not TorchElasticEnvironment.detect()
97+
98+
with mock.patch.dict(
99+
os.environ,
100+
{
101+
"TORCHELASTIC_RUN_ID": "",
102+
},
103+
):
104+
assert TorchElasticEnvironment.detect()

tests/plugins/test_cluster_integration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def environment_combinations():
4949
"RANK": "3",
5050
"WORLD_SIZE": "4",
5151
"LOCAL_WORLD_SIZE": "2",
52+
"TORCHELASTIC_RUN_ID": "1",
5253
}
5354
environment = TorchElasticEnvironment()
5455
yield environment, variables, expected

0 commit comments

Comments
 (0)