Commit c55bc43

Fix retrieval of batch indices when dataloader num_workers > 0 (#10870)

Authored by awaelchli and rohitgr7
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 98cb7e8, commit c55bc43

7 files changed (+133, -54 lines)

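To make the fixed bug concrete before the per-file diffs: the sketch below is an illustration, not code from this repository; `RecordingBatchSampler` and `ToyDataset` are stand-ins for `IndexBatchSamplerWrapper` and a real dataset. With `num_workers > 0`, the `DataLoader` consumes the batch sampler ahead of the batches the main loop receives (prefetching), so a "last seen" `batch_indices` attribute races ahead of the batch actually being processed. Keeping the full history and truncating it to the number of consumed batches, as this commit does, restores the pairing.

import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class RecordingBatchSampler:
    """Stand-in for IndexBatchSamplerWrapper: records the index batches it hands out."""

    def __init__(self, batch_sampler):
        self._batch_sampler = batch_sampler
        self.batch_indices = []       # last batch only (the old, racy attribute)
        self.seen_batch_indices = []  # full history (what this commit introduces)

    def __iter__(self):
        self.seen_batch_indices = []
        for batch in self._batch_sampler:
            self.batch_indices = batch
            self.seen_batch_indices.append(batch)
            yield batch

    def __len__(self):
        return len(self._batch_sampler)


class ToyDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, index):
        return torch.tensor(index)


if __name__ == "__main__":
    sampler = RecordingBatchSampler(
        BatchSampler(SequentialSampler(range(16)), batch_size=4, drop_last=False)
    )
    # num_workers=2 makes the loader prefetch: the sampler is iterated ahead
    # of the batches the main loop receives.
    loader = DataLoader(ToyDataset(), batch_sampler=sampler, num_workers=2)
    for step, batch in enumerate(loader):
        true_indices = batch.tolist()
        stale = sampler.batch_indices                       # often a *later* batch
        fixed = sampler.seen_batch_indices[: step + 1][-1]  # truncate, then take last
        print(step, true_indices, "stale:", stale, "fixed:", fixed)
        assert fixed == true_indices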

CHANGELOG.md (5 additions, 2 deletions)

@@ -103,7 +103,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the property `Trainer.slurm_job_id` in favor of the new `SLURMEnvironment.job_id()` method ([#10622](https://github.com/PyTorchLightning/pytorch-lightning/pull/10622))


--
+- Deprecated the access to the attribute `IndexBatchSamplerWrapper.batch_indices` in favor of `IndexBatchSamplerWrapper.seen_batch_indices` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))
+

 ### Removed

@@ -227,12 +228,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed early schedule reset logic in PyTorch profiler that was causing data leak ([#10837](https://github.com/PyTorchLightning/pytorch-lightning/pull/10837))


--
+- Fixed a bug that caused incorrect batch indices to be passed to the `BasePredictionWriter` hooks when using a dataloader with `num_workers > 0` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))
+


 -


+
 ## [1.5.4] - 2021-11-30

 ### Fixed

pytorch_lightning/loops/epoch/prediction_epoch_loop.py (19 additions, 17 deletions)

@@ -26,7 +26,7 @@ def __init__(self) -> None:
         self._dl_max_batches = 0
         self._num_dataloaders = 0
         self._warning_cache = WarningCache()
-        self._all_batch_indices: List[int] = []
+        self._seen_batch_indices: List[List[int]] = []

     @property
     def done(self) -> bool:
@@ -44,7 +44,7 @@ def connect(self, **kwargs: "Loop") -> None:

     def reset(self) -> None:
         """Resets the loops internal state."""
-        self._all_batch_indices = []
+        self._seen_batch_indices = []
         self.predictions = []
         self.batch_progress.reset_on_run()

@@ -68,6 +68,7 @@ def on_run_start(  # type: ignore[override]
         void(dataloader_iter, dataloader_idx)
         self._dl_max_batches = dl_max_batches
         self._num_dataloaders = num_dataloaders
+        self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
         self.return_predictions = return_predictions

     def advance(  # type: ignore[override]
@@ -88,6 +89,10 @@ def advance(  # type: ignore[override]
             return_predictions: whether to return the obtained predictions
         """
         batch_idx, batch = next(dataloader_iter)
+        self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
+        # we need to truncate the list of batch indices due to prefetching in the dataloader and Lightning
+        self._seen_batch_indices = self._seen_batch_indices[: (self.batch_progress.current.completed + 1)]
+
         if batch is None:
             raise StopIteration

@@ -99,13 +104,10 @@ def advance(  # type: ignore[override]
         with self.trainer.profiler.profile("predict_step"):
             self._predict_step(batch, batch_idx, dataloader_idx)

-    def on_run_end(self) -> Tuple[List[Any], List[int]]:
+    def on_run_end(self) -> Tuple[List[Any], List[List[int]]]:
         """Returns the predictions and the corresponding batch indices."""
-        predictions = self.predictions
-        all_batch_indices = self._all_batch_indices
-        # free memory
-        self.predictions = []
-        self._all_batch_indices = []
+        predictions, all_batch_indices = self.predictions, self._seen_batch_indices
+        self.predictions, self._seen_batch_indices = [], []  # free memory
         return predictions, all_batch_indices

     def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
@@ -121,7 +123,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
         step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

         # extract batch_indices and store them
-        self._store_batch_indices(dataloader_idx)
+        self.current_batch_indices = self._seen_batch_indices[batch_idx] if self._seen_batch_indices else []

         model_ref = self.trainer.lightning_module

@@ -160,12 +162,12 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict
             step_kwargs["dataloader_idx"] = dataloader_idx
         return step_kwargs

-    def _store_batch_indices(self, dataloader_idx: int) -> None:
-        """Stores the batch indices if the predictions should be stored."""
+    def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
+        """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
+        :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""
         batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler
-        if isinstance(batch_sampler, IndexBatchSamplerWrapper):
-            self.current_batch_indices = batch_sampler.batch_indices
-            if self.should_store_predictions:
-                self._all_batch_indices.append(batch_sampler.batch_indices)
-        else:
-            warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
+        if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions:
+            return batch_sampler.seen_batch_indices
+
+        warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
+        return []
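For illustration, a small plain-Python walk-through of the truncation added in `advance` above. All values are assumed for the example; `completed` mirrors `self.batch_progress.current.completed`.

# The wrapped sampler has already produced three batches of indices,
# because the dataloader prefetches ahead of the loop...
seen_batch_indices = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
completed = 1  # ...but only one batch has been fully processed so far.

# Keep the completed batches plus the batch currently in flight:
truncated = seen_batch_indices[: completed + 1]
assert truncated == [[0, 1, 2, 3], [4, 5, 6, 7]]

# `_predict_step` then selects the entry for the current batch:
batch_idx = 1
current_batch_indices = truncated[batch_idx] if truncated else []
assert current_batch_indices == [4, 5, 6, 7]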

pytorch_lightning/overrides/distributed.py (23 additions, 3 deletions)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, cast, Iterator, List, Optional, Sized, Union
+from typing import Any, cast, Iterator, List, Sized, Union

 import torch
 from torch import Tensor
@@ -21,6 +21,7 @@

 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
+from pytorch_lightning.utilities import rank_zero_deprecation


 class LightningDistributedModule(_LightningModuleWrapperBase):
@@ -123,12 +124,31 @@ class IndexBatchSamplerWrapper:
     """This class is used to wrap a :class:`torch.utils.data.BatchSampler` and capture its indices."""

     def __init__(self, sampler: BatchSampler) -> None:
+        self.seen_batch_indices: List[List[int]] = []
         self._sampler = sampler
-        self.batch_indices: Optional[List[int]] = None
+        self._batch_indices: List[int] = []
+
+    @property
+    def batch_indices(self) -> List[int]:
+        rank_zero_deprecation(
+            "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5 and will be removed in"
+            " v1.7. Access the full list `seen_batch_indices` instead."
+        )
+        return self._batch_indices
+
+    @batch_indices.setter
+    def batch_indices(self, indices: List[int]) -> None:
+        rank_zero_deprecation(
+            "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5 and will be removed in"
+            " v1.7. Access the full list `seen_batch_indices` instead."
+        )
+        self._batch_indices = indices

     def __iter__(self) -> Iterator[List[int]]:
+        self.seen_batch_indices = []
         for batch in self._sampler:
-            self.batch_indices = batch
+            self._batch_indices = batch
+            self.seen_batch_indices.append(batch)
             yield batch

     def __len__(self) -> int:
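A quick behavioural check of the new capture logic, assuming a Lightning installation that already includes this commit; the sampler sizes are arbitrary. The full history accumulates during iteration and is reset whenever a new iteration starts, so downstream code can truncate it safely.

from torch.utils.data import BatchSampler, SequentialSampler

from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper

wrapper = IndexBatchSamplerWrapper(
    BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
)
batches = list(wrapper)
assert batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
# the full history is kept; accessing the deprecated `batch_indices` would warn
assert wrapper.seen_batch_indices == batches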

tests/callbacks/test_prediction_writer.py (73 additions, 29 deletions)

@@ -11,54 +11,98 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest.mock import ANY, call, Mock

 import pytest
+from torch.utils.data import DataLoader

 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import BasePredictionWriter
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel
+from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf


-def test_prediction_writer(tmpdir):
-    class CustomPredictionWriter(BasePredictionWriter):
-        def __init__(self, writer_interval: str):
-            super().__init__(writer_interval)
+class DummyPredictionWriter(BasePredictionWriter):
+    def write_on_batch_end(self, *args, **kwargs):
+        pass

-            self.write_on_batch_end_called = False
-            self.write_on_epoch_end_called = False
+    def write_on_epoch_end(self, *args, **kwargs):
+        pass

-        def write_on_batch_end(self, *args, **kwargs):
-            self.write_on_batch_end_called = True
-
-        def write_on_epoch_end(self, *args, **kwargs):
-            self.write_on_epoch_end_called = True

+def test_prediction_writer_invalid_write_interval():
+    """Test that configuring an unknown interval name raises an error."""
     with pytest.raises(MisconfigurationException, match=r"`write_interval` should be one of \['batch"):
-        CustomPredictionWriter("something")
+        DummyPredictionWriter("something")
+
+
+def test_prediction_writer_hook_call_intervals(tmpdir):
+    """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined
+    interval."""
+    DummyPredictionWriter.write_on_batch_end = Mock()
+    DummyPredictionWriter.write_on_epoch_end = Mock()
+
+    dataloader = DataLoader(RandomDataset(32, 64))

     model = BoringModel()
-    cb = CustomPredictionWriter("batch_and_epoch")
+    cb = DummyPredictionWriter("batch_and_epoch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    results = trainer.predict(model, dataloaders=model.train_dataloader())
+    results = trainer.predict(model, dataloaders=dataloader)
     assert len(results) == 4
-    assert cb.write_on_batch_end_called
-    assert cb.write_on_epoch_end_called
+    assert cb.write_on_batch_end.call_count == 4
+    assert cb.write_on_epoch_end.call_count == 1

-    cb = CustomPredictionWriter("batch_and_epoch")
+    DummyPredictionWriter.write_on_batch_end.reset_mock()
+    DummyPredictionWriter.write_on_epoch_end.reset_mock()
+
+    cb = DummyPredictionWriter("batch_and_epoch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
-    assert cb.write_on_batch_end_called
-    assert cb.write_on_epoch_end_called
+    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
+    assert cb.write_on_batch_end.call_count == 4
+    assert cb.write_on_epoch_end.call_count == 1
+
+    DummyPredictionWriter.write_on_batch_end.reset_mock()
+    DummyPredictionWriter.write_on_epoch_end.reset_mock()

-    cb = CustomPredictionWriter("batch")
+    cb = DummyPredictionWriter("batch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
-    assert cb.write_on_batch_end_called
-    assert not cb.write_on_epoch_end_called
+    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
+    assert cb.write_on_batch_end.call_count == 4
+    assert cb.write_on_epoch_end.call_count == 0
+
+    DummyPredictionWriter.write_on_batch_end.reset_mock()
+    DummyPredictionWriter.write_on_epoch_end.reset_mock()

-    cb = CustomPredictionWriter("epoch")
+    cb = DummyPredictionWriter("epoch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
-    assert not cb.write_on_batch_end_called
-    assert cb.write_on_epoch_end_called
+    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
+    assert cb.write_on_batch_end.call_count == 0
+    assert cb.write_on_epoch_end.call_count == 1
+
+
+@pytest.mark.parametrize("num_workers", [0, pytest.param(2, marks=RunIf(slow=True))])
+def test_prediction_writer_batch_indices(tmpdir, num_workers):
+    DummyPredictionWriter.write_on_batch_end = Mock()
+    DummyPredictionWriter.write_on_epoch_end = Mock()
+
+    dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers)
+    model = BoringModel()
+    writer = DummyPredictionWriter("batch_and_epoch")
+    trainer = Trainer(limit_predict_batches=4, callbacks=writer)
+    trainer.predict(model, dataloaders=dataloader)
+
+    writer.write_on_batch_end.assert_has_calls(
+        [
+            call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0),
+            call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0),
+            call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0),
+            call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0),
+        ]
+    )
+
+    writer.write_on_epoch_end.assert_has_calls(
+        [
+            call(trainer, model, ANY, [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]),
+        ]
+    )
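For completeness, a hedged sketch of a user-side writer that consumes the now-correct indices. `IndexAwareWriter` and its `output_dir` argument are hypothetical, not part of the `BasePredictionWriter` API; the hook signatures mirror the calls asserted in the test above.

import os

import torch

from pytorch_lightning.callbacks import BasePredictionWriter


class IndexAwareWriter(BasePredictionWriter):
    """Illustrative writer: persists each prediction together with the dataset
    indices it was computed from. `output_dir` is a hypothetical argument."""

    def __init__(self, output_dir: str, write_interval: str = "batch"):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(
        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
    ):
        # With this commit, `batch_indices` lines up with `batch` even when
        # the predict dataloader uses num_workers > 0.
        torch.save(
            {"batch_indices": batch_indices, "prediction": prediction},
            os.path.join(self.output_dir, f"predictions_{dataloader_idx}_{batch_idx}.pt"),
        )

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # `batch_indices` is nested per dataloader, then per batch, mirroring
        # the `write_on_epoch_end` assertion in the test above.
        torch.save(
            {"batch_indices": batch_indices, "predictions": predictions},
            os.path.join(self.output_dir, "predictions_epoch.pt"),
        )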

tests/deprecated_api/test_remove_1-7.py (11 additions, 0 deletions)

@@ -14,6 +14,7 @@
 """Test deprecated functionality which will be removed in v1.7.0."""
 import os
 from unittest import mock
+from unittest.mock import Mock

 import pytest

@@ -23,6 +24,7 @@
 from pytorch_lightning.callbacks.progress import ProgressBar
 from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
 from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
+from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
 from pytorch_lightning.plugins.environments import (
     KubeflowEnvironment,
     LightningEnvironment,
@@ -528,3 +530,12 @@ def is_using_torchelastic():
         match=f"MyClusterEnvironment.{method_name}` has been deprecated in v1.6 and will be removed in v1.7"
     ):
         MyClusterEnvironment()
+
+
+def test_v1_7_0_index_batch_sampler_wrapper_batch_indices():
+    sampler = IndexBatchSamplerWrapper(Mock())
+    with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
+        _ = sampler.batch_indices
+
+    with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
+        sampler.batch_indices = []

tests/deprecated_api/test_remove_1-8.py (1 addition, 0 deletions)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Test deprecated functionality which will be removed in v1.8.0."""
+
 import pytest
 import torch
tests/overrides/test_distributed.py (1 addition, 3 deletions)

@@ -54,9 +54,7 @@ def test_index_batch_sampler(tmpdir):
     assert batch_sampler.batch_size == index_batch_sampler.batch_size
     assert batch_sampler.drop_last == index_batch_sampler.drop_last
     assert batch_sampler.sampler is sampler
-
-    for batch in index_batch_sampler:
-        assert index_batch_sampler.batch_indices == batch
+    assert list(index_batch_sampler) == index_batch_sampler.seen_batch_indices


 def test_index_batch_sampler_methods():
