
Commit 84bdcd4

awaelchli and rohitgr7 committed
Fix retrieval of batch indices when dataloader num_workers > 0 (#10870)
Co-authored-by: Rohit Gupta <[email protected]>
1 parent f26f637 · commit 84bdcd4

File tree

6 files changed: +129 additions, -63 deletions

CHANGELOG.md

Lines changed: 1 addition & 10 deletions
@@ -12,18 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with `SignalConnector` not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled ([#10611](https://github.com/PyTorchLightning/pytorch-lightning/pull/10611))
 - Fixed `SignalConnector._has_already_handler` check for callable type ([#10483](https://github.com/PyTorchLightning/pytorch-lightning/pull/10483))
 - Improved exception message if `rich` version is less than `10.2.2` ([#10839](https://github.com/PyTorchLightning/pytorch-lightning/pull/10839))
-
-
 - Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))
-
-
 - Fixed early schedule reset logic in PyTorch profiler that was causing data leak ([#10837](https://github.com/PyTorchLightning/pytorch-lightning/pull/10837))
-
-
--
-
-
--
+- Fixed a bug that caused incorrect batch indices to be passed to the `BasePredictionWriter` hooks when using a dataloader with `num_workers > 0` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))


 ## [1.5.4] - 2021-11-30

pytorch_lightning/loops/epoch/prediction_epoch_loop.py

Lines changed: 20 additions & 18 deletions
@@ -26,7 +26,7 @@ def __init__(self) -> None:
         self._dl_max_batches: Optional[int] = None
         self._num_dataloaders: Optional[int] = None
         self._warning_cache = WarningCache()
-        self._all_batch_indices: List[int] = []
+        self._seen_batch_indices: List[List[int]] = []

     @property
     def done(self) -> bool:
@@ -44,8 +44,8 @@ def connect(self, **kwargs: "Loop") -> None:

     def reset(self) -> None:
         """Resets the loops internal state."""
-        self._all_batch_indices: List[int] = []
-        self.predictions: List[Any] = []
+        self._seen_batch_indices = []
+        self.predictions = []
         self.batch_progress.reset_on_run()

     def on_run_start(
@@ -68,6 +68,7 @@ def on_run_start(
         void(dataloader_iter, dataloader_idx)
         self._dl_max_batches = dl_max_batches
         self._num_dataloaders = num_dataloaders
+        self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
         self.return_predictions = return_predictions

     def advance(
@@ -88,6 +89,10 @@ def advance(
             return_predictions: whether to return the obtained predictions
         """
         batch_idx, batch = next(dataloader_iter)
+        self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
+        # we need to truncate the list of batch indicies due to prefetching in the dataloader and Lightning
+        self._seen_batch_indices = self._seen_batch_indices[: (self.batch_progress.current.completed + 1)]
+
         if batch is None:
             raise StopIteration

@@ -99,13 +104,10 @@ def advance(
         with self.trainer.profiler.profile("predict_step"):
             self._predict_step(batch, batch_idx, dataloader_idx)

-    def on_run_end(self) -> Tuple[List[Any], List[int]]:
+    def on_run_end(self) -> Tuple[List[Any], List[List[int]]]:
         """Returns the predictions and the corresponding batch indices."""
-        predictions = self.predictions
-        all_batch_indices = self._all_batch_indices
-        # free memory
-        self.predictions = []
-        self._all_batch_indices = []
+        predictions, all_batch_indices = self.predictions, self._seen_batch_indices
+        self.predictions, self._seen_batch_indices = [], []  # free memory
         return predictions, all_batch_indices

     def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
@@ -121,7 +123,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
         step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

         # extract batch_indices and store them
-        self._store_batch_indices(dataloader_idx)
+        self.current_batch_indices = self._seen_batch_indices[batch_idx] if self._seen_batch_indices else []

         model_ref = self.trainer.lightning_module

@@ -160,12 +162,12 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict
         step_kwargs["dataloader_idx"] = dataloader_idx
         return step_kwargs

-    def _store_batch_indices(self, dataloader_idx: int) -> None:
-        """Stores the batch indices if the predictions should be stored."""
+    def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
+        """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
+        :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""
         batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler
-        if isinstance(batch_sampler, IndexBatchSamplerWrapper):
-            self.current_batch_indices = batch_sampler.batch_indices
-            if self.should_store_predictions:
-                self._all_batch_indices.append(batch_sampler.batch_indices)
-        else:
-            warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
+        if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions:
+            return batch_sampler.seen_batch_indices
+
+        warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
+        return []
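The core of the fix is visible above: instead of reading the sampler's "current" batch indices at prediction time, the loop now takes the full history recorded by the wrapped batch sampler, truncates it to the batches actually completed (the dataloader may have prefetched further ahead), and indexes into it with `batch_idx`. The following is a minimal, self-contained sketch of why this is more robust when workers prefetch; `RangeDataset` and `RecordingBatchSampler` are illustrative stand-ins, not Lightning code.

# Sketch: with num_workers > 0 the DataLoader prefetches, so a "latest batch"
# attribute on the sampler can run ahead of the batch currently being consumed.
# Recording every yielded batch and indexing by batch_idx stays correct.
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class RangeDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return torch.tensor(idx)


class RecordingBatchSampler:
    """Wraps a batch sampler and records every batch of indices it yields."""

    def __init__(self, sampler):
        self._sampler = sampler
        self.seen_batch_indices = []

    def __iter__(self):
        self.seen_batch_indices = []
        for batch in self._sampler:
            self.seen_batch_indices.append(batch)
            yield batch

    def __len__(self):
        return len(self._sampler)


if __name__ == "__main__":
    batch_sampler = RecordingBatchSampler(
        BatchSampler(SequentialSampler(range(16)), batch_size=4, drop_last=False)
    )
    loader = DataLoader(RangeDataset(), batch_sampler=batch_sampler, num_workers=2)

    for batch_idx, batch in enumerate(loader):
        # Index into the recorded history instead of trusting the sampler's
        # "latest" batch, which prefetching may already have advanced.
        assert batch_sampler.seen_batch_indices[batch_idx] == batch.tolist()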

pytorch_lightning/overrides/distributed.py

Lines changed: 23 additions & 3 deletions
@@ -12,14 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, Iterator, List, Optional
+from typing import Any, Iterator, List

 import torch
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import BatchSampler, DistributedSampler, Sampler

 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
+from pytorch_lightning.utilities import rank_zero_deprecation


 class LightningDistributedModule(_LightningModuleWrapperBase):
@@ -119,12 +120,31 @@ class IndexBatchSamplerWrapper:
     """This class is used to wrap a :class:`torch.utils.data.BatchSampler` and capture its indices."""

     def __init__(self, sampler: BatchSampler) -> None:
+        self.seen_batch_indices: List[List[int]] = []
         self._sampler = sampler
-        self.batch_indices: Optional[List[int]] = None
+        self._batch_indices: List[int] = []
+
+    @property
+    def batch_indices(self) -> List[int]:
+        rank_zero_deprecation(
+            "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5 and will be removed in"
+            " v1.7. Access the full list `seen_batch_indices` instead."
+        )
+        return self._batch_indices
+
+    @batch_indices.setter
+    def batch_indices(self, indices: List[int]) -> None:
+        rank_zero_deprecation(
+            "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5 and will be removed in"
+            " v1.7. Access the full list `seen_batch_indices` instead."
+        )
+        self._batch_indices = indices

     def __iter__(self) -> Iterator[List[int]]:
+        self.seen_batch_indices = []
         for batch in self._sampler:
-            self.batch_indices = batch
+            self._batch_indices = batch
+            self.seen_batch_indices.append(batch)
             yield batch

     def __len__(self) -> int:
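The wrapper keeps the old `batch_indices` attribute alive through a property/setter pair that warns before forwarding to a private field, so code written against v1.5 keeps working until the attribute is removed in v1.7. Below is a generic sketch of that deprecation pattern; `LegacyAttributeExample` is illustrative, and `warnings.warn` stands in for Lightning's `rank_zero_deprecation`.

import warnings
from typing import List


class LegacyAttributeExample:
    def __init__(self) -> None:
        self._batch_indices: List[int] = []
        self.seen_batch_indices: List[List[int]] = []

    @property
    def batch_indices(self) -> List[int]:
        # Reads still work, but emit a DeprecationWarning pointing at the replacement.
        warnings.warn("`batch_indices` is deprecated, use `seen_batch_indices`", DeprecationWarning)
        return self._batch_indices

    @batch_indices.setter
    def batch_indices(self, indices: List[int]) -> None:
        # Writes are forwarded to the private backing field after warning.
        warnings.warn("`batch_indices` is deprecated, use `seen_batch_indices`", DeprecationWarning)
        self._batch_indices = indices


obj = LegacyAttributeExample()
obj.batch_indices = [0, 1, 2, 3]  # warns, but still works
print(obj.batch_indices)          # warns again, prints [0, 1, 2, 3]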

tests/callbacks/test_prediction_writer.py

Lines changed: 73 additions & 29 deletions
@@ -11,54 +11,98 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest.mock import ANY, call, Mock

 import pytest
+from torch.utils.data import DataLoader

 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import BasePredictionWriter
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
+from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf


-def test_prediction_writer(tmpdir):
-    class CustomPredictionWriter(BasePredictionWriter):
-        def __init__(self, writer_interval: str):
-            super().__init__(writer_interval)
+class DummyPredictionWriter(BasePredictionWriter):
+    def write_on_batch_end(self, *args, **kwargs):
+        pass

-            self.write_on_batch_end_called = False
-            self.write_on_epoch_end_called = False
+    def write_on_epoch_end(self, *args, **kwargs):
+        pass

-        def write_on_batch_end(self, *args, **kwargs):
-            self.write_on_batch_end_called = True
-
-        def write_on_epoch_end(self, *args, **kwargs):
-            self.write_on_epoch_end_called = True

+def test_prediction_writer_invalid_write_interval():
+    """Test that configuring an unknown interval name raises an error."""
     with pytest.raises(MisconfigurationException, match=r"`write_interval` should be one of \['batch"):
-        CustomPredictionWriter("something")
+        DummyPredictionWriter("something")
+
+
+def test_prediction_writer_hook_call_intervals(tmpdir):
+    """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined
+    interval."""
+    DummyPredictionWriter.write_on_batch_end = Mock()
+    DummyPredictionWriter.write_on_epoch_end = Mock()
+
+    dataloader = DataLoader(RandomDataset(32, 64))

     model = BoringModel()
-    cb = CustomPredictionWriter("batch_and_epoch")
+    cb = DummyPredictionWriter("batch_and_epoch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    results = trainer.predict(model, dataloaders=model.train_dataloader())
+    results = trainer.predict(model, dataloaders=dataloader)
     assert len(results) == 4
-    assert cb.write_on_batch_end_called
-    assert cb.write_on_epoch_end_called
+    assert cb.write_on_batch_end.call_count == 4
+    assert cb.write_on_epoch_end.call_count == 1

-    cb = CustomPredictionWriter("batch_and_epoch")
+    DummyPredictionWriter.write_on_batch_end.reset_mock()
+    DummyPredictionWriter.write_on_epoch_end.reset_mock()
+
+    cb = DummyPredictionWriter("batch_and_epoch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
-    assert cb.write_on_batch_end_called
-    assert cb.write_on_epoch_end_called
+    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
+    assert cb.write_on_batch_end.call_count == 4
+    assert cb.write_on_epoch_end.call_count == 1
+
+    DummyPredictionWriter.write_on_batch_end.reset_mock()
+    DummyPredictionWriter.write_on_epoch_end.reset_mock()

-    cb = CustomPredictionWriter("batch")
+    cb = DummyPredictionWriter("batch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
-    assert cb.write_on_batch_end_called
-    assert not cb.write_on_epoch_end_called
+    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
+    assert cb.write_on_batch_end.call_count == 4
+    assert cb.write_on_epoch_end.call_count == 0
+
+    DummyPredictionWriter.write_on_batch_end.reset_mock()
+    DummyPredictionWriter.write_on_epoch_end.reset_mock()

-    cb = CustomPredictionWriter("epoch")
+    cb = DummyPredictionWriter("epoch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
-    assert not cb.write_on_batch_end_called
-    assert cb.write_on_epoch_end_called
+    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
+    assert cb.write_on_batch_end.call_count == 0
+    assert cb.write_on_epoch_end.call_count == 1
+
+
+@pytest.mark.parametrize("num_workers", [0, pytest.param(2, marks=RunIf(slow=True))])
+def test_prediction_writer_batch_indices(tmpdir, num_workers):
+    DummyPredictionWriter.write_on_batch_end = Mock()
+    DummyPredictionWriter.write_on_epoch_end = Mock()
+
+    dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers)
+    model = BoringModel()
+    writer = DummyPredictionWriter("batch_and_epoch")
+    trainer = Trainer(limit_predict_batches=4, callbacks=writer)
+    trainer.predict(model, dataloaders=dataloader)
+
+    writer.write_on_batch_end.assert_has_calls(
+        [
+            call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0),
+            call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0),
+            call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0),
+            call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0),
+        ]
+    )
+
+    writer.write_on_epoch_end.assert_has_calls(
+        [
+            call(trainer, model, ANY, [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]),
+        ]
+    )
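These tests pin down the user-visible contract: `write_on_batch_end` receives the dataset indices that produced each prediction, and `write_on_epoch_end` receives the full nested list, regardless of `num_workers`. The sketch below shows a writer that relies on that contract to name output files after dataset indices; the class name and output layout are illustrative, not part of the Lightning API.

import os

import torch
from pytorch_lightning.callbacks import BasePredictionWriter


class IndexedPredictionWriter(BasePredictionWriter):
    """Illustrative writer: saves one file per sample, named after its dataset index."""

    def __init__(self, output_dir: str):
        super().__init__(write_interval="batch")
        self.output_dir = output_dir

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        # With the fix above, batch_indices lines up with the samples in `prediction`
        # even when the dataloader uses num_workers > 0.
        for sample_idx, pred in zip(batch_indices, prediction):
            torch.save(pred, os.path.join(self.output_dir, f"prediction_{sample_idx}.pt"))

Passing such a writer via `Trainer(callbacks=[IndexedPredictionWriter("some_dir")])` and calling `trainer.predict(...)` would then write one file per dataset sample.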

tests/deprecated_api/test_remove_1-7.py

Lines changed: 11 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Test deprecated functionality which will be removed in v1.7.0."""
 from unittest import mock
+from unittest.mock import Mock

 import pytest

@@ -22,6 +23,7 @@
 from pytorch_lightning.callbacks.progress import ProgressBar
 from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
 from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
+from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
 from tests.callbacks.test_callbacks import OldStatefulCallback
 from tests.deprecated_api import _soft_unimport_module
 from tests.helpers import BoringModel
@@ -448,3 +450,12 @@ def test_v1_7_0_deprecate_lr_sch_names(tmpdir):

     with pytest.deprecated_call(match="`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5"):
         assert lr_monitor.lr_sch_names == ["lr-SGD"]
+
+
+def test_v1_7_0_index_batch_sampler_wrapper_batch_indices():
+    sampler = IndexBatchSamplerWrapper(Mock())
+    with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
+        _ = sampler.batch_indices
+
+    with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
+        sampler.batch_indices = []
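The new test relies on `pytest.deprecated_call`, which passes only if the wrapped block emits a `DeprecationWarning` (or a subclass, which Lightning's deprecation warnings are) whose message matches the given pattern. A self-contained sketch of the same testing pattern, with a hypothetical `emit_legacy_warning` helper:

import warnings

import pytest


def emit_legacy_warning():
    warnings.warn("`batch_indices` was deprecated in v1.5", DeprecationWarning)


def test_emits_deprecation_warning():
    # Fails if no matching DeprecationWarning is raised inside the block.
    with pytest.deprecated_call(match="deprecated in v1.5"):
        emit_legacy_warning()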

tests/overrides/test_distributed.py

Lines changed: 1 addition & 3 deletions
@@ -54,9 +54,7 @@ def test_index_batch_sampler(tmpdir):
     assert batch_sampler.batch_size == index_batch_sampler.batch_size
     assert batch_sampler.drop_last == index_batch_sampler.drop_last
     assert batch_sampler.sampler is sampler
-
-    for batch in index_batch_sampler:
-        assert index_batch_sampler.batch_indices == batch
+    assert list(index_batch_sampler) == index_batch_sampler.seen_batch_indices


 def test_index_batch_sampler_methods():
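The simplified assertion above checks that draining the wrapper leaves the complete history in `seen_batch_indices`. A quick usage sketch against the API introduced in this commit (assuming a Lightning version that includes it):

from torch.utils.data import BatchSampler, SequentialSampler

from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper

# Wrap a standard BatchSampler and drain it; every yielded batch is recorded.
wrapper = IndexBatchSamplerWrapper(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
batches = list(wrapper)
assert batches == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
assert wrapper.seen_batch_indices == batches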
