Skip to content

Commit 83d74bb

Browse files
low5545awaelchlitchaton
authored
Fix reset_seed() converting the PL_SEED_WORKERS environment variable str read to bool (#10099)
Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: tchaton <[email protected]>
1 parent 9af1dd7 commit 83d74bb

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
665665
- Fixed creation of `dirpath` in `BaseProfiler` if it doesn't exist ([#10073](https://github.com/PyTorchLightning/pytorch-lightning/pull/10073))
666666

667667

668+
- Fixed an issue with `pl.utilities.seed.reset_seed` converting the `PL_SEED_WORKERS` environment variable to `bool` ([#10099](https://github.com/PyTorchLightning/pytorch-lightning/pull/10099))
669+
670+
671+
668672
## [1.4.9] - 2021-09-30
669673

670674
- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))

pytorch_lightning/utilities/seed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ def reset_seed() -> None:
8888
If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing.
8989
"""
9090
seed = os.environ.get("PL_GLOBAL_SEED", None)
91-
workers = os.environ.get("PL_SEED_WORKERS", False)
91+
workers = os.environ.get("PL_SEED_WORKERS", "0")
9292
if seed is not None:
93-
seed_everything(int(seed), workers=bool(workers))
93+
seed_everything(int(seed), workers=bool(int(workers)))
9494

9595

9696
def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover

tests/utilities/test_seed.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,19 @@ def test_reset_seed_no_op():
5656
assert "PL_GLOBAL_SEED" not in os.environ
5757

5858

59-
def test_reset_seed_everything():
59+
@pytest.mark.parametrize("workers", (True, False))
60+
def test_reset_seed_everything(workers):
6061
"""Test that we can reset the seed to the initial value set by seed_everything()"""
6162
assert "PL_GLOBAL_SEED" not in os.environ
62-
seed_utils.seed_everything(123)
63-
assert os.environ["PL_GLOBAL_SEED"] == "123"
63+
assert "PL_SEED_WORKERS" not in os.environ
64+
65+
seed_utils.seed_everything(123, workers)
6466
before = torch.rand(1)
67+
assert os.environ["PL_GLOBAL_SEED"] == "123"
68+
assert os.environ["PL_SEED_WORKERS"] == str(int(workers))
69+
6570
seed_utils.reset_seed()
6671
after = torch.rand(1)
72+
assert os.environ["PL_GLOBAL_SEED"] == "123"
73+
assert os.environ["PL_SEED_WORKERS"] == str(int(workers))
6774
assert torch.allclose(before, after)

0 commit comments

Comments
 (0)