This repository was archived by the owner on Sep 28, 2022. It is now read-only.

Commit dffa3a1

tchaton authored and Raalsky committed

shutdown workers on failure (Lightning-AI#10463)

1 parent 3bce41a · commit dffa3a1

File tree

3 files changed: +32 −7 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -131,6 +131,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))


+- Fixed an issue that prevented the Trainer to shutdown workers when execution is interrupted due to failure ([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))
+
+
 - Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))
pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 0 deletions

@@ -694,6 +694,8 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs:
             # reset bookkeeping
             self.state.stage = None
             self.on_exception(exception)
+            # shutdown workers
+            self._data_connector.teardown()
             raise

     def fit(
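For readers skimming the diff: the two added lines sit in the `except` branch of `_call_and_handle_interrupt`, so any exception raised by the wrapped entrypoint now triggers a DataLoader-worker teardown before the exception is re-raised. The snippet below is a minimal sketch of that control flow, not the actual Lightning implementation; the `_MiniTrainer` class, its `data_connector` argument, and the printed message are illustrative stand-ins, and only the placement of `teardown()` before `raise` mirrors the change above.

```python
from typing import Any, Callable


class _MiniTrainer:
    """Illustrative stand-in for the Trainer, showing only the cleanup-on-failure pattern."""

    def __init__(self, data_connector: Any) -> None:
        # assumed to expose a `teardown()` method that shuts down DataLoader workers
        self._data_connector = data_connector

    def on_exception(self, exception: BaseException) -> None:
        # placeholder for the exception hook called in the real code path
        print(f"run failed with: {exception!r}")

    def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
        try:
            return trainer_fn(*args, **kwargs)
        except BaseException as exception:
            self.on_exception(exception)
            # shutdown workers: release DataLoader worker processes even though
            # the run is aborting, so they do not linger after the failure
            self._data_connector.teardown()
            raise
```

The essential point carried over from the diff is that the teardown call runs before `raise`, so worker processes are released even when the run ends in an exception.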

tests/loops/test_loops.py

Lines changed: 27 additions & 7 deletions

@@ -24,7 +24,7 @@

 from pl_examples.bug_report_model import RandomDataset
 from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import Callback, ModelCheckpoint
 from pytorch_lightning.loops import Loop, TrainingBatchLoop
 from pytorch_lightning.trainer.progress import BaseProgress
 from tests.helpers import BoringModel

@@ -907,8 +907,10 @@ def val_dataloader(self):


 @RunIf(min_torch="1.8.0")
-@pytest.mark.parametrize("persistent_workers", (False, True))
-def test_workers_are_shutdown(tmpdir, persistent_workers):
+@pytest.mark.parametrize("should_fail", [False, True])
+# False is de-activated due to slowness
+@pytest.mark.parametrize("persistent_workers", [True])
+def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
     # `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
     # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance

@@ -936,12 +938,30 @@ def _get_iterator(self):
     train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
     val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)

+    class TestCallback(Callback):
+        def on_train_epoch_end(self, trainer, *_):
+            if trainer.current_epoch == 1:
+                raise CustomException
+
     max_epochs = 3
+
     model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=max_epochs)
-    trainer.fit(model, train_dataloader, val_dataloader)
-    assert train_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=max_epochs,
+        callbacks=TestCallback() if should_fail else None,
+    )
+
+    if should_fail:
+        with pytest.raises(CustomException):
+            trainer.fit(model, train_dataloader, val_dataloader)
+    else:
+        trainer.fit(model, train_dataloader, val_dataloader)
+
+    assert train_dataloader.count_shutdown_workers == 2 if should_fail else (2 if persistent_workers else max_epochs)
     # on sanity checking end, the workers are being deleted too.
-    assert val_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs + 1)
+    assert val_dataloader.count_shutdown_workers == 2 if persistent_workers else (3 if should_fail else max_epochs + 1)
     assert train_dataloader._iterator is None
     assert val_dataloader._iterator is None
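The `TestDataLoader` used in this test is defined outside the visible hunks, so its implementation is not shown here. As a rough, hypothetical reconstruction of how a DataLoader can count worker shutdowns, the sketch below wraps the iterator returned by `_get_iterator()`; the name `CountingDataLoader` is invented for illustration, and `_shutdown_workers` is a private attribute of PyTorch's `_MultiProcessingDataLoaderIter` (only present when `num_workers > 0`) whose behaviour may change between PyTorch versions.

```python
from torch.utils.data import DataLoader


class CountingDataLoader(DataLoader):
    """Hypothetical stand-in for the test's `TestDataLoader`: records how often the
    multiprocessing iterator shuts its worker processes down."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.count_shutdown_workers = 0

    def _get_iterator(self):
        iterator = super()._get_iterator()
        if self.num_workers > 0:
            # wrap the private `_shutdown_workers` hook so every call is counted
            original_shutdown = iterator._shutdown_workers

            def counting_shutdown(*args, **kwargs):
                self.count_shutdown_workers += 1
                return original_shutdown(*args, **kwargs)

            iterator._shutdown_workers = counting_shutdown
        return iterator


# Intended usage mirrors the test above, e.g.:
#   loader = CountingDataLoader(some_dataset, num_workers=1, persistent_workers=True)
#   ... run Trainer.fit(...) with `loader`, possibly interrupting it with an exception ...
#   assert loader.count_shutdown_workers >= 1
```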
