Commit f6f81f0

ananthsub and carmocca authored
[fix] Add a cluster environment teardown to clean up environment state (#6942)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 7b0b6c3 commit f6f81f0

File tree

5 files changed (+28, -4 lines)

CHANGELOG.md

Lines changed: 7 additions & 1 deletion

```diff
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added a `teardown` hook to `ClusterEnvironment` ([#6942](https://github.com/PyTorchLightning/pytorch-lightning/pull/6942))
+
+
 - Added utils for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))
 
 
@@ -196,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed incorrect removal of `WORLD_SIZE` environment variable in DDP training when launching with torch distributed/torchelastic ([#6942](https://github.com/PyTorchLightning/pytorch-lightning/pull/6942))
+
+
 - Set better defaults for `rank_zero_only.rank` when training is launched with SLURM and torchelastic:
     * Support SLURM and torchelastic global rank environment variables ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
     * Remove hardcoding of local rank in accelerator connector ([#6878](https://github.com/PyTorchLightning/pytorch-lightning/pull/6878))
@@ -243,7 +249,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
 
 
-- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917))
+- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917))
 
 
 - Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
```

pytorch_lightning/plugins/environments/cluster_environment.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -52,3 +52,7 @@ def local_rank(self) -> int:
     @abstractmethod
     def node_rank(self) -> int:
         """ The rank (index) of the node on which the current process runs. """
+
+    def teardown(self) -> None:
+        """ Clean up any state set after execution finishes. """
+        pass
```
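The base-class hook is a no-op by default, so each environment opts in to whatever cleanup it needs. A minimal self-contained sketch of that pattern (the class names below are illustrative stand-ins, not the real `pytorch_lightning` classes):

```python
import os
from abc import ABC, abstractmethod


class ClusterEnvironment(ABC):
    """Sketch of the interface this commit extends (not the real class)."""

    @abstractmethod
    def node_rank(self) -> int:
        """ The rank (index) of the node on which the current process runs. """

    def teardown(self) -> None:
        """ Clean up any state set after execution finishes. """
        pass  # no-op by default; subclasses override as needed


class MyClusterEnvironment(ClusterEnvironment):
    """Hypothetical custom environment that stashes state and cleans it up."""

    def __init__(self) -> None:
        # hypothetical state this environment sets during setup
        os.environ.setdefault("MY_CLUSTER_FLAG", "1")

    def node_rank(self) -> int:
        return int(os.environ.get("NODE_RANK", 0))

    def teardown(self) -> None:
        # remove only the state this environment owns
        os.environ.pop("MY_CLUSTER_FLAG", None)


env = MyClusterEnvironment()
assert "MY_CLUSTER_FLAG" in os.environ
env.teardown()
assert "MY_CLUSTER_FLAG" not in os.environ
```

Keeping `teardown` on the base class means callers can invoke it unconditionally, without checking which concrete environment they hold.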

pytorch_lightning/plugins/environments/lightning_environment.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -68,6 +68,10 @@ def node_rank(self) -> int:
         group_rank = os.environ.get("GROUP_RANK", 0)
         return int(os.environ.get("NODE_RANK", group_rank))
 
+    def teardown(self) -> None:
+        if "WORLD_SIZE" in os.environ:
+            del os.environ["WORLD_SIZE"]
+
 
 def find_free_network_port() -> int:
     """
```
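The two methods in the hunk above can be exercised on their own. This is a trimmed, runnable sketch of just those methods (the real `LightningEnvironment` has more), showing the `GROUP_RANK` fallback and the new cleanup:

```python
import os


class LightningEnvironment:
    """Trimmed sketch of the two methods shown in the diff (not the full class)."""

    def node_rank(self) -> int:
        # torchelastic sets GROUP_RANK; fall back to it when NODE_RANK is absent
        group_rank = os.environ.get("GROUP_RANK", 0)
        return int(os.environ.get("NODE_RANK", group_rank))

    def teardown(self) -> None:
        # drop the WORLD_SIZE variable set during training so a second run
        # in the same process starts from a clean environment
        if "WORLD_SIZE" in os.environ:
            del os.environ["WORLD_SIZE"]


os.environ.pop("NODE_RANK", None)
os.environ["GROUP_RANK"] = "3"
env = LightningEnvironment()
assert env.node_rank() == 3  # GROUP_RANK substitutes for a missing NODE_RANK

os.environ["WORLD_SIZE"] = "4"
env.teardown()
assert "WORLD_SIZE" not in os.environ
```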

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -280,9 +280,8 @@ def pre_dispatch(self):
 
         self.barrier()
 
-    def post_dispatch(self):
-        if "WORLD_SIZE" in os.environ:
-            del os.environ["WORLD_SIZE"]
+    def post_dispatch(self) -> None:
+        self.cluster_environment.teardown()
 
     def barrier(self, *args, **kwargs):
         if torch_distrib.is_initialized():
```

tests/plugins/environments/test_lightning_environment.py

Lines changed: 11 additions & 0 deletions

```diff
@@ -55,3 +55,14 @@ def test_random_master_port():
     assert isinstance(port, int)
     # repeated calls do not generate a new port number
     assert env.master_port() == port
+
+
+@mock.patch.dict(os.environ, {
+    "WORLD_SIZE": "1",
+})
+def test_teardown():
+    """ Test that `teardown()` removes the `WORLD_SIZE` environment variable. """
+    env = LightningEnvironment()
+    assert "WORLD_SIZE" in os.environ
+    env.teardown()
+    assert "WORLD_SIZE" not in os.environ
```
