|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +from unittest.mock import ANY, call, Mock |
14 | 15 |
|
15 | 16 | import pytest
|
| 17 | +from torch.utils.data import DataLoader |
16 | 18 |
|
17 | 19 | from pytorch_lightning import Trainer
|
18 | 20 | from pytorch_lightning.callbacks import BasePredictionWriter
|
19 | 21 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
20 |
| -from tests.helpers import BoringModel |
| 22 | +from tests.helpers import BoringModel, RandomDataset |
| 23 | +from tests.helpers.runif import RunIf |
21 | 24 |
|
22 | 25 |
|
23 |
| -def test_prediction_writer(tmpdir): |
24 |
| - class CustomPredictionWriter(BasePredictionWriter): |
25 |
| - def __init__(self, writer_interval: str): |
26 |
| - super().__init__(writer_interval) |
| 26 | +class DummyPredictionWriter(BasePredictionWriter): |
| 27 | + def write_on_batch_end(self, *args, **kwargs): |
| 28 | + pass |
27 | 29 |
|
28 |
| - self.write_on_batch_end_called = False |
29 |
| - self.write_on_epoch_end_called = False |
| 30 | + def write_on_epoch_end(self, *args, **kwargs): |
| 31 | + pass |
30 | 32 |
|
31 |
| - def write_on_batch_end(self, *args, **kwargs): |
32 |
| - self.write_on_batch_end_called = True |
33 |
| - |
34 |
| - def write_on_epoch_end(self, *args, **kwargs): |
35 |
| - self.write_on_epoch_end_called = True |
36 | 33 |
|
| 34 | +def test_prediction_writer_invalid_write_interval(): |
| 35 | + """Test that configuring an unknown interval name raises an error.""" |
37 | 36 | with pytest.raises(MisconfigurationException, match=r"`write_interval` should be one of \['batch"):
|
38 |
| - CustomPredictionWriter("something") |
| 37 | + DummyPredictionWriter("something") |
| 38 | + |
| 39 | + |
| 40 | +def test_prediction_writer_hook_call_intervals(tmpdir): |
| 41 | + """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined |
| 42 | + interval.""" |
| 43 | + DummyPredictionWriter.write_on_batch_end = Mock() |
| 44 | + DummyPredictionWriter.write_on_epoch_end = Mock() |
| 45 | + |
| 46 | + dataloader = DataLoader(RandomDataset(32, 64)) |
39 | 47 |
|
40 | 48 | model = BoringModel()
|
41 |
| - cb = CustomPredictionWriter("batch_and_epoch") |
| 49 | + cb = DummyPredictionWriter("batch_and_epoch") |
42 | 50 | trainer = Trainer(limit_predict_batches=4, callbacks=cb)
|
43 |
| - results = trainer.predict(model, dataloaders=model.train_dataloader()) |
| 51 | + results = trainer.predict(model, dataloaders=dataloader) |
44 | 52 | assert len(results) == 4
|
45 |
| - assert cb.write_on_batch_end_called |
46 |
| - assert cb.write_on_epoch_end_called |
| 53 | + assert cb.write_on_batch_end.call_count == 4 |
| 54 | + assert cb.write_on_epoch_end.call_count == 1 |
47 | 55 |
|
48 |
| - cb = CustomPredictionWriter("batch_and_epoch") |
| 56 | + DummyPredictionWriter.write_on_batch_end.reset_mock() |
| 57 | + DummyPredictionWriter.write_on_epoch_end.reset_mock() |
| 58 | + |
| 59 | + cb = DummyPredictionWriter("batch_and_epoch") |
49 | 60 | trainer = Trainer(limit_predict_batches=4, callbacks=cb)
|
50 |
| - trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) |
51 |
| - assert cb.write_on_batch_end_called |
52 |
| - assert cb.write_on_epoch_end_called |
| 61 | + trainer.predict(model, dataloaders=dataloader, return_predictions=False) |
| 62 | + assert cb.write_on_batch_end.call_count == 4 |
| 63 | + assert cb.write_on_epoch_end.call_count == 1 |
| 64 | + |
| 65 | + DummyPredictionWriter.write_on_batch_end.reset_mock() |
| 66 | + DummyPredictionWriter.write_on_epoch_end.reset_mock() |
53 | 67 |
|
54 |
| - cb = CustomPredictionWriter("batch") |
| 68 | + cb = DummyPredictionWriter("batch") |
55 | 69 | trainer = Trainer(limit_predict_batches=4, callbacks=cb)
|
56 |
| - trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) |
57 |
| - assert cb.write_on_batch_end_called |
58 |
| - assert not cb.write_on_epoch_end_called |
| 70 | + trainer.predict(model, dataloaders=dataloader, return_predictions=False) |
| 71 | + assert cb.write_on_batch_end.call_count == 4 |
| 72 | + assert cb.write_on_epoch_end.call_count == 0 |
| 73 | + |
| 74 | + DummyPredictionWriter.write_on_batch_end.reset_mock() |
| 75 | + DummyPredictionWriter.write_on_epoch_end.reset_mock() |
59 | 76 |
|
60 |
| - cb = CustomPredictionWriter("epoch") |
| 77 | + cb = DummyPredictionWriter("epoch") |
61 | 78 | trainer = Trainer(limit_predict_batches=4, callbacks=cb)
|
62 |
| - trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) |
63 |
| - assert not cb.write_on_batch_end_called |
64 |
| - assert cb.write_on_epoch_end_called |
| 79 | + trainer.predict(model, dataloaders=dataloader, return_predictions=False) |
| 80 | + assert cb.write_on_batch_end.call_count == 0 |
| 81 | + assert cb.write_on_epoch_end.call_count == 1 |
| 82 | + |
| 83 | + |
| 84 | +@pytest.mark.parametrize("num_workers", [0, pytest.param(2, marks=RunIf(slow=True))]) |
| 85 | +def test_prediction_writer_batch_indices(tmpdir, num_workers): |
| 86 | + DummyPredictionWriter.write_on_batch_end = Mock() |
| 87 | + DummyPredictionWriter.write_on_epoch_end = Mock() |
| 88 | + |
| 89 | + dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers) |
| 90 | + model = BoringModel() |
| 91 | + writer = DummyPredictionWriter("batch_and_epoch") |
| 92 | + trainer = Trainer(limit_predict_batches=4, callbacks=writer) |
| 93 | + trainer.predict(model, dataloaders=dataloader) |
| 94 | + |
| 95 | + writer.write_on_batch_end.assert_has_calls( |
| 96 | + [ |
| 97 | + call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0), |
| 98 | + call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0), |
| 99 | + call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0), |
| 100 | + call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0), |
| 101 | + ] |
| 102 | + ) |
| 103 | + |
| 104 | + writer.write_on_epoch_end.assert_has_calls( |
| 105 | + [ |
| 106 | + call(trainer, model, ANY, [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]), |
| 107 | + ] |
| 108 | + ) |
0 commit comments