 
 from pl_examples.bug_report_model import RandomDataset
 from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import Callback, ModelCheckpoint
 from pytorch_lightning.loops import Loop, TrainingBatchLoop
 from pytorch_lightning.trainer.progress import BaseProgress
 from tests.helpers import BoringModel
@@ -907,8 +907,10 @@ def val_dataloader(self):
 
 
 @RunIf(min_torch="1.8.0")
-@pytest.mark.parametrize("persistent_workers", (False, True))
-def test_workers_are_shutdown(tmpdir, persistent_workers):
+@pytest.mark.parametrize("should_fail", [False, True])
+# False is de-activated due to slowness
+@pytest.mark.parametrize("persistent_workers", [True])
+def test_workers_are_shutdown(tmpdir, should_fail, persistent_workers):
     # `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
     # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance
 
@@ -936,12 +938,30 @@ def _get_iterator(self):
     train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
     val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
 
+    class TestCallback(Callback):
+        def on_train_epoch_end(self, trainer, *_):
+            if trainer.current_epoch == 1:
+                raise CustomException
+
     max_epochs = 3
+
     model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=max_epochs)
-    trainer.fit(model, train_dataloader, val_dataloader)
-    assert train_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=max_epochs,
+        callbacks=TestCallback() if should_fail else None,
+    )
+
+    if should_fail:
+        with pytest.raises(CustomException):
+            trainer.fit(model, train_dataloader, val_dataloader)
+    else:
+        trainer.fit(model, train_dataloader, val_dataloader)
+
+    assert train_dataloader.count_shutdown_workers == 2 if should_fail else (2 if persistent_workers else max_epochs)
     # on sanity checking end, the workers are being deleted too.
-    assert val_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs + 1)
+    assert val_dataloader.count_shutdown_workers == 2 if persistent_workers else (3 if should_fail else max_epochs + 1)
     assert train_dataloader._iterator is None
     assert val_dataloader._iterator is None
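Note: the hunks above rely on helpers defined elsewhere in the test file and not shown in this diff: `CustomException`, `TestDataLoader`, and its `count_shutdown_workers` attribute. The following is only a minimal sketch of what such helpers could look like; the names, structure, and the monkey-patching approach are assumptions, not the file's actual definitions, and `_shutdown_workers` is a private PyTorch API.

from torch.utils.data import DataLoader


class CustomException(Exception):
    """Raised by the test callback to abort training mid-run."""


class TestDataLoader(DataLoader):
    """Sketch: a DataLoader that counts how often its worker pool is shut down."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.count_shutdown_workers = 0

    def _get_iterator(self):
        # With `num_workers >= 1` this returns a `_MultiProcessingDataLoaderIter`.
        # Wrap its private `_shutdown_workers` method so that every shutdown
        # increments the counter on the owning DataLoader.
        iterator = super()._get_iterator()
        original_shutdown = iterator._shutdown_workers

        def counting_shutdown():
            self.count_shutdown_workers += 1
            original_shutdown()

        iterator._shutdown_workers = counting_shutdown
        return iterator

The real test may instead subclass `_MultiProcessingDataLoaderIter` directly; either way, only the `count_shutdown_workers` counter and the `_iterator` attribute matter to the assertions in the diff.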