
Commit 6309a59

Do not prefetch when possible (#12101)
1 parent ed7ccca commit 6309a59

12 files changed: +142 -105 lines changed

docs/source/guides/data.rst
Lines changed: 3 additions & 0 deletions

@@ -393,6 +393,9 @@ option when using sequential data.
 to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception.
 Here ``mode`` can be train/val/test/predict.
 
+When iterable datasets are used, Lightning will pre-fetch 1 batch (in addition to the current batch) so it can detect
+when the training will stop and run validation if necessary.
+
 .. testcode::
 
     # IterableDataset
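The added paragraph above is the user-facing summary of this change: when a dataloader wraps an unsized iterable dataset, the loop cannot consult `len()` to know which batch is the last, so it keeps exactly one batch of look-ahead. As a rough, hypothetical illustration (not code from this commit), the look-ahead pattern boils down to the following, where `StreamDataset` is an invented stand-in for any `IterableDataset` without `__len__`:

import torch
from torch.utils.data import DataLoader, IterableDataset


class StreamDataset(IterableDataset):
    # hypothetical unsized dataset: it deliberately defines no __len__
    def __iter__(self):
        for i in range(5):  # pretend this is a stream of unknown length
            yield torch.tensor([float(i)])


loader = DataLoader(StreamDataset(), batch_size=2)
iterator = iter(loader)
current = next(iterator)  # always hold one batch of look-ahead
while True:
    try:
        lookahead = next(iterator)
        is_last = False
    except StopIteration:
        lookahead = None
        is_last = True
    print(current, "is_last:", is_last)  # the flag is known while `current` is being processed
    if is_last:
        break
    current = lookahead

With a sized dataloader this extra fetch is unnecessary, which is exactly what the loop changes below avoid.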

pytorch_lightning/loops/dataloader/evaluation_loop.py
Lines changed: 8 additions & 1 deletion

@@ -81,6 +81,13 @@ def dataloaders(self) -> Sequence[DataLoader]:
             raise RuntimeError("Dataloaders should be available.")
         return dataloaders
 
+    @property
+    def prefetch_batches(self) -> int:
+        batches = self.trainer.num_test_batches if self.trainer.testing else self.trainer.num_val_batches
+        is_unsized = batches[self.current_dataloader_idx] == float("inf")
+        inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
+        return 1 if is_unsized or inter_batch_parallelism else 0
+
     def connect(self, epoch_loop: EvaluationEpochLoop) -> None:  # type: ignore[override]
         """Connect the evaluation epoch loop with this loop."""
         self.epoch_loop = epoch_loop

@@ -121,7 +128,7 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None:
         void(*args, **kwargs)
 
         data_fetcher_cls = _select_data_fetcher_type(self.trainer)
-        self._data_fetcher = data_fetcher_cls()
+        self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
 
         # hook
         self._on_evaluation_model_eval()
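The new `prefetch_batches` property encodes the decision rule: prefetch one batch only when the number of evaluation batches is unknown (`float("inf")`), or when inter-batch parallelism was explicitly requested through the `PL_INTER_BATCH_PARALLELISM` environment variable; otherwise prefetch nothing. A small standalone sketch of the same rule (`decide_prefetch_batches` is a hypothetical helper, not part of the commit):

import os


def decide_prefetch_batches(num_batches: float) -> int:
    # hypothetical standalone helper restating the property's rule
    is_unsized = num_batches == float("inf")
    inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
    return 1 if is_unsized or inter_batch_parallelism else 0


assert decide_prefetch_batches(10) == 0            # sized loader: no prefetching
assert decide_prefetch_batches(float("inf")) == 1  # unsized loader: one batch of look-ahead

os.environ["PL_INTER_BATCH_PARALLELISM"] = "1"     # explicit opt-in keeps prefetching on
assert decide_prefetch_batches(10) == 1
del os.environ["PL_INTER_BATCH_PARALLELISM"]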

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
Lines changed: 2 additions & 0 deletions

@@ -85,6 +85,8 @@ def on_run_start(  # type: ignore[override]
         self._reload_dataloader_state_dict(data_fetcher)
         # creates the iterator inside the fetcher but returns `self`
         self._data_fetcher = cast(AbstractDataFetcher, iter(data_fetcher))
+        # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
+        data_fetcher.fetched += self.batch_progress.current.ready
 
     def advance(  # type: ignore[override]
         self,

pytorch_lightning/loops/epoch/training_epoch_loop.py
Lines changed: 3 additions & 1 deletion

@@ -142,7 +142,9 @@ def reset(self) -> None:
 
     def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
         self._reload_dataloader_state_dict(data_fetcher)
-        iter(data_fetcher)  # creates the iterator inside the fetcher
+        _ = iter(data_fetcher)  # creates the iterator inside the fetcher
+        # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
+        data_fetcher.fetched += self.batch_progress.current.ready
 
     def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
         """Runs a single training batch.

pytorch_lightning/loops/fit_loop.py
Lines changed: 8 additions & 1 deletion

@@ -149,6 +149,12 @@ def restarting(self, restarting: bool) -> None:
         restarting &= finished_before_on_train_end
         Loop.restarting.fset(self, restarting)  # call the parent setter
 
+    @property
+    def prefetch_batches(self) -> int:
+        is_unsized = self.trainer.num_training_batches == float("inf")
+        inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
+        return 1 if is_unsized or inter_batch_parallelism else 0
+
     @property
     def _skip_backward(self) -> bool:
         """Determines whether the loop will skip backward during automatic optimization."""

@@ -213,8 +219,9 @@ def on_run_start(self) -> None:  # type: ignore[override]
         """Calls the ``on_train_start`` hook."""
         # reset train dataloader and val dataloader
         self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
+
         data_fetcher_cls = _select_data_fetcher(self.trainer)
-        self._data_fetcher = data_fetcher_cls()
+        self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
 
         self._is_fresh_start_epoch = True
         self._results.to(device=self.trainer.lightning_module.device)

pytorch_lightning/utilities/data.py
Lines changed: 13 additions & 20 deletions

@@ -89,17 +89,13 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool:
 
 def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
-    infinite dataloader.
-
-    Raises:
-        ValueError:
-            If the length of Dataloader is 0, as it requires at least one batch
-    """
-
+    infinite dataloader."""
     try:
         # try getting the length
         if len(dataloader) == 0:
-            raise ValueError("`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch")
+            rank_zero_warn(
+                f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
+            )
         has_len = True
     except TypeError:
         has_len = False

@@ -122,30 +118,27 @@ def has_len_all_ranks(
     model: Union["pl.LightningModule", "pl.LightningDataModule"],
 ) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
-    infinite dataloader.
-
-    Raises:
-        ValueError:
-            If the length of Dataloader is 0, as it requires at least one batch
-    """
+    infinite dataloader."""
     try:
-        total_length = training_type.reduce(torch.tensor(len(dataloader)).to(model.device), reduce_op="sum")
         local_length = len(dataloader)
+        total_length = training_type.reduce(torch.tensor(local_length).to(model.device), reduce_op="sum")
 
         if total_length == 0:
-            raise MisconfigurationException(
-                "Total length of `Dataloader` across ranks is zero. Please make sure that it returns at least 1 batch."
+            rank_zero_warn(
+                f"Total length of `{dataloader.__class__.__name__}` across ranks is zero."
+                " Please make sure this was your intention."
             )
         if total_length > 0 and local_length == 0:
             if model.allow_zero_length_dataloader_with_multiple_devices:
                 rank_zero_warn(
-                    "Total length of `Dataloader` across ranks is zero, but local rank has zero length."
-                    " Please be cautious of uneven batch length."
+                    f"Total length of `{dataloader.__class__.__name__}` across ranks is zero, but local rank has zero"
+                    " length. Please be cautious of uneven batch length."
                 )
                 has_len = False
             else:
                 raise MisconfigurationException(
-                    "`Dataloader` within local rank has zero length. Please make sure that it returns at least 1 batch."
+                    f"`{dataloader.__class__.__name__}` within local rank has zero length."
+                    " Please make sure that it returns at least 1 batch."
                )
         else:
            has_len = True
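In practice, a zero-length but sized dataloader now produces a warning instead of a hard failure. A minimal sketch of the new behaviour, assuming `has_len` is imported from `pytorch_lightning.utilities.data` as this commit does elsewhere, and using an empty `TensorDataset` in place of the test helper `RandomDataset`:

import warnings

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.data import has_len

# an empty but sized dataset: len(empty_loader) == 0
empty_loader = DataLoader(TensorDataset(torch.empty(0, 32)))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert has_len(empty_loader)  # still considered "sized"; only a warning is emitted now
assert any("returned 0 length" in str(w.message) for w in caught)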

pytorch_lightning/utilities/fetching.py
Lines changed: 17 additions & 5 deletions

@@ -15,7 +15,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Iterator
 from copy import deepcopy
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Sized, Tuple
 
 import torch
 from torch.utils.data.dataloader import DataLoader

@@ -30,6 +30,7 @@
     MergedIteratorState,
     patch_dataloader_iterator,
 )
+from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 

@@ -79,6 +80,8 @@ def __init__(self, prefetch_batches: int = 0) -> None:
     def setup(self, dataloader: Iterable, **kwargs: Any) -> None:
         self._add_capture_metadata_collate(dataloader)
         self._dataloader = dataloader
+        _patch_dataloader_get_iterators()
+        self._attach_data_fetcher()
 
     @property
     def dataloader(self) -> Iterable:

@@ -172,8 +175,6 @@ def _attach_data_fetcher_fn(loader: DataLoader) -> None:
 
     def __iter__(self) -> "AbstractDataFetcher":
        self.reset()
-        self._attach_data_fetcher()
-        _patch_dataloader_get_iterators()
        self.dataloader_iter = iter(self.dataloader)
        self._apply_patch()
        self.prefetching()

@@ -205,7 +206,7 @@ class DataFetcher(AbstractDataFetcher):
 
     Args:
         prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
-            whether a batch is the last one (available with :attr:`self.done`).
+            whether a batch is the last one (available with :attr:`self.done`) under any training setup.
         store_on_device: Whether to store the pre-fetched batches on device.
     """
 

@@ -214,11 +215,13 @@ def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None:
         self.store_on_device = store_on_device
         self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device
         self.batches: List[Any] = []
+        self._has_len = False
 
     def setup(  # type: ignore[override]
         self, dataloader: Iterable, batch_to_device: Optional[Callable[[Any], Any]] = None
     ) -> None:
         super().setup(dataloader)
+        self._has_len = has_len(dataloader)
         if batch_to_device is not None:
             self.batch_to_device = batch_to_device
 

@@ -233,6 +236,9 @@ def prefetching(self) -> None:
             try:
                 self._fetch_next_batch(iterator)
             except StopIteration:
+                # this would only happen when prefetch_batches > the number of batches available and makes
+                # `fetching_function` jump directly to the empty iterator case without trying to fetch again
+                self.done = True
                 break
 
     def fetching_function(self) -> Any:

@@ -266,6 +272,11 @@ def _fetch_next_batch(self, iterator: Iterator) -> None:
         start_output = self.on_fetch_start()
         batch = next(iterator)
         self.fetched += 1
+        if not self.prefetch_batches and self._has_len:
+            # when we don't prefetch but the dataloader is sized, we use the length for `done`
+            dataloader = self.dataloader
+            assert isinstance(dataloader, Sized)  # `_has_len` is True
+            self.done = self.fetched >= len(dataloader)
         self.on_fetch_end(batch, start_output)
 
     def move_to_device(self, batch: Any) -> Any:

@@ -360,7 +371,8 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
             ...
     """
 
-    def __init__(self) -> None:
+    def __init__(self, prefetch_batches: int = 0) -> None:
+        # prefetch batches is not used for this class
         super().__init__()
         self.store_on_device = False
 
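Taken together, the fetcher changes give `DataFetcher` two ways to decide it is `done`: a sized dataloader with `prefetch_batches=0` compares `fetched` against `len(dataloader)`, while an unsized one keeps a one-batch look-ahead and relies on `StopIteration`. The class below is a deliberately simplified, hypothetical sketch of that logic; `TinyFetcher` is not the real `DataFetcher`:

from typing import Any, Iterator, List, Sized


class TinyFetcher:
    """A stripped-down sketch of the termination logic, not the real DataFetcher."""

    def __init__(self, dataloader: Any, prefetch_batches: int) -> None:
        self.dataloader = dataloader
        self.prefetch_batches = prefetch_batches
        self.fetched = 0
        self.done = False
        self.batches: List[Any] = []
        self._has_len = hasattr(dataloader, "__len__")  # stand-in for `has_len(dataloader)`

    def _fetch_next_batch(self, iterator: Iterator) -> None:
        batch = next(iterator)
        self.fetched += 1
        if not self.prefetch_batches and self._has_len:
            # sized dataloader and no prefetching: the length alone tells us when we are done
            assert isinstance(self.dataloader, Sized)
            self.done = self.fetched >= len(self.dataloader)
        self.batches.append(batch)

    def __iter__(self):
        iterator = iter(self.dataloader)
        # unsized (or opted-in) case: keep `prefetch_batches` batches of look-ahead
        for _ in range(self.prefetch_batches):
            try:
                self._fetch_next_batch(iterator)
            except StopIteration:
                self.done = True
                break
        while self.batches or not self.done:
            if not self.done:
                try:
                    self._fetch_next_batch(iterator)
                except StopIteration:
                    self.done = True
            if self.batches:
                batch = self.batches.pop(0)
                yield batch, self.done and not self.batches


for batch, is_last in TinyFetcher([10, 20, 30], prefetch_batches=0):
    print(batch, is_last)  # 10 False, 20 False, 30 True -- no extra fetch needed

unsized = (x for x in [10, 20, 30])  # a generator has no __len__
for batch, is_last in TinyFetcher(unsized, prefetch_batches=1):
    print(batch, is_last)  # 10 False, 20 False, 30 True -- detected via the look-ahead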

tests/loops/test_loops.py
Lines changed: 0 additions & 6 deletions

@@ -648,16 +648,12 @@ def train_dataloader(self):
                 "ready": n_epochs,
                 "started": n_epochs,
                 "processed": n_epochs,
-                # TODO: the following "-1" offset will be fixed by
-                # https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
                 "completed": n_epochs - 1,
             },
             "current": {
                 "ready": n_epochs,
                 "started": n_epochs,
                 "processed": n_epochs,
-                # TODO: the following "-1" offset will be fixed by
-                # https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
                 "completed": n_epochs - 1,
             },
         },

@@ -956,8 +952,6 @@ def val_dataloader(self):
     # totals are increased by 1 (the failed batch which never completed)
     expected = state_dict.copy()
 
-    # TODO: `is_last_batch` is not correct on reload, the next line should not be necessary
-    expected["epoch_loop.batch_progress"]["is_last_batch"] = val_check_interval == 1.0
     assert state_dict_after_restart["epoch_loop.batch_progress"] == expected["epoch_loop.batch_progress"]
 
     val_dl_progress = "epoch_loop.val_loop.dataloader_progress"

tests/trainer/test_dataloaders.py
Lines changed: 6 additions & 10 deletions

@@ -516,20 +516,16 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
     assert len(trainer.test_dataloaders) == 1
 
 
-def test_error_on_zero_len_dataloader(tmpdir):
-    """Test that error is raised if a zero-length dataloader is defined."""
-
-    class CustomBoringModel(BoringModel):
-        def train_dataloader(self):
-            return DataLoader(RandomDataset(32, 0))
-
-    model = CustomBoringModel()
+def test_warning_on_zero_len_dataloader(tmpdir):
+    """Test that a warning is raised if a zero-length dataloader is defined."""
+    model = BoringModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         fast_dev_run=1,
     )
-    with pytest.raises(ValueError, match="returned 0 length. .* at least 1 batch"):
-        trainer.fit(model)
+    dataloader = DataLoader(RandomDataset(32, 0))
+    with pytest.warns(UserWarning, match="returned 0 length"):
+        trainer.fit(model, dataloader)
 
 
 @RunIf(skip_windows=True)

tests/utilities/test_auto_restart.py
Lines changed: 2 additions & 2 deletions

@@ -1452,7 +1452,7 @@ def load_state_dict(self, state_dict):
 
 
 class RandomFaultTolerantSampler(RandomSampler):
-    def __init__(self, *args, seed: int = 0, generator=None, **kwargs):
+    def __init__(self, *args, seed: int = 0, **kwargs):
         generator = torch.Generator().manual_seed(seed)
         super().__init__(*args, generator=generator, **kwargs)
         self.counter = 0

@@ -1558,7 +1558,7 @@ def configure_optimizers(self):
     seed_everything(42)
     model = TestModel(should_fail=True)
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=val_check_interval)
-    with suppress(CustomException):
+    with pytest.raises(CustomException):
         trainer.fit(model)
     trainer.train_dataloader = None
     failed_batches = model.batches

tests/utilities/test_data.py
Lines changed: 3 additions & 3 deletions

@@ -93,7 +93,7 @@ def __iter__(self):
 def test_has_len():
     assert has_len(DataLoader(RandomDataset(1, 1)))
 
-    with pytest.raises(ValueError, match="`Dataloader` returned 0 length."):
+    with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
         assert has_len(DataLoader(RandomDataset(0, 0)))
 
     assert not has_len(DataLoader(RandomIterableDataset(1, 1)))

@@ -112,8 +112,8 @@ def test_has_len_all_rank():
     trainer = Trainer(fast_dev_run=True)
     model = BoringModel()
 
-    with pytest.raises(MisconfigurationException, match="Total length of `Dataloader` across ranks is zero."):
-        assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy, model)
+    with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero."):
+        assert has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy, model)
 
     assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.strategy, model)
