Commit b427fdf

awaelchli and tchaton committed
Resolve workers being forcefully deleted with persistent_workers=True (#10434)
Co-authored-by: Thomas Chaton <[email protected]>
1 parent aae797f · commit b427fdf

File tree

3 files changed: +16 −9 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed deadlocks for distributed training with `RichProgressBar` ([#10428](https://github.com/PyTorchLightning/pytorch-lightning/pull/10428))
 - Fixed an issue where the model wrapper in Lite converted non-floating point tensors to float ([#10429](https://github.com/PyTorchLightning/pytorch-lightning/pull/10429))
 - Fixed an issue with inferring the dataset type in fault-tolerant training ([#10432](https://github.com/PyTorchLightning/pytorch-lightning/pull/10432))
+- Fixed dataloader workers with `persistent_workers` being deleted on every iteration ([#10434](https://github.com/PyTorchLightning/pytorch-lightning/pull/10434))
 
 
 ## [1.5.0] - 2021-11-02
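
For context on the entry above: persistent_workers=True asks PyTorch's DataLoader to keep its worker processes alive between epochs instead of respawning them. A minimal, Lightning-free sketch of that contract, where TinyDataset and the sizes are illustrative only:

import os

import torch
from torch.utils.data import DataLoader, Dataset


class TinyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # report the worker pid so we can see whether workers were respawned
        return torch.tensor(os.getpid())


if __name__ == "__main__":
    loader = DataLoader(TinyDataset(), num_workers=1, persistent_workers=True)

    pids = set()
    for epoch in range(3):
        for batch in loader:
            pids.update(batch.tolist())

    # with persistent_workers=True the same worker serves every epoch;
    # the bug fixed here made Lightning shut it down after each epoch anyway
    assert len(pids) == 1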

pytorch_lightning/utilities/fetching.py

Lines changed: 3 additions & 3 deletions

@@ -206,15 +206,15 @@ def reset(self) -> None:
         self.batches: List = []
         self.fetched: int = 0
         self.done: bool = False
+
+    def teardown(self) -> None:
+        self.reset()
         if isinstance(self.dataloader, CombinedLoader):
             self.dataloader.reset()
         if isinstance(self.dataloader, DataLoader):
             CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
         self.dataloader_iter = None
 
-    def teardown(self) -> None:
-        self.reset()
-
 
 class DataFetcher(AbstractDataFetcher):
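The substance of the change: the worker-shutdown block moves out of reset(), which the loops invoke at every epoch boundary, and into teardown(), which runs once when fitting ends. A minimal sketch of that call pattern, using a hypothetical FetcherSketch class rather than Lightning's actual fetcher:

class FetcherSketch:
    def __init__(self):
        self.worker_shutdowns = 0

    def reset(self):
        # per-epoch bookkeeping only; the dataloader workers stay alive,
        # which is what persistent_workers=True promises
        self.batches, self.fetched, self.done = [], 0, False

    def teardown(self):
        # one-time cleanup when fitting ends; the increment stands in for
        # CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
        self.reset()
        self.worker_shutdowns += 1


fetcher = FetcherSketch()
for epoch in range(3):
    fetcher.reset()  # before the fix, this call also shut the workers down
fetcher.teardown()
assert fetcher.worker_shutdowns == 1  # workers are torn down exactly once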
tests/loops/test_loops.py

Lines changed: 12 additions & 6 deletions

@@ -912,21 +912,25 @@ def val_dataloader(self):
 
 
 @RunIf(min_torch="1.8.0")
-@pytest.mark.parametrize("persistent_workers", (True, False))
+@pytest.mark.parametrize("persistent_workers", (False, True))
 def test_workers_are_shutdown(tmpdir, persistent_workers):
     # `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
     # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance
 
     class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
-        def __init__(self, *args, dataloader: DataLoader, **kwargs):
+        def __init__(self, *args, dataloader, **kwargs):
             super().__init__(*args, **kwargs)
             self.dataloader = dataloader
 
         def _shutdown_workers(self):
-            setattr(self.dataloader, "has_shutdown_workers", True)
+            self.dataloader.count_shutdown_workers += 1
             super()._shutdown_workers()
 
     class TestDataLoader(DataLoader):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.count_shutdown_workers = 0
+
         def _get_iterator(self):
             if self.num_workers == 0:
                 return super()._get_iterator()
@@ -937,10 +941,12 @@ def _get_iterator(self):
     train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
     val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
 
+    max_epochs = 3
     model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2)
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=max_epochs)
     trainer.fit(model, train_dataloader, val_dataloader)
-    assert train_dataloader.has_shutdown_workers
-    assert val_dataloader.has_shutdown_workers
+    assert train_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs)
+    # on sanity checking end, the workers are being deleted too.
+    assert val_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs + 1)
     assert train_dataloader._iterator is None
     assert val_dataloader._iterator is None
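
The new assertions pin down the fixed behavior rather than merely checking that a shutdown happened: a non-persistent loader respawns its workers each epoch, so they are shut down max_epochs times for training and max_epochs + 1 times for validation (the sanity-check workers are torn down too, as the inline comment notes), while a persistent loader keeps its workers across epochs. The count of 2 rather than 1 for the persistent case presumably reflects that _shutdown_workers runs once explicitly during teardown and once more when the discarded iterator object is finalized.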
