Skip to content

Commit 7da931d

Browse files
authored
Support no pre-fetching (#11606)
1 parent c71a1d7 commit 7da931d

File tree

3 files changed

+43
-28
lines changed

3 files changed

+43
-28
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7272
- Added `LightningModule.lr_scheduler_step` ([#10249](https://github.com/PyTorchLightning/pytorch-lightning/pull/10249))
7373

7474

75+
- Added support for no pre-fetching to `DataFetcher` ([#11606](https://github.com/PyTorchLightning/pytorch-lightning/pull/11606))
76+
77+
7578
- Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247))
7679

7780

pytorch_lightning/utilities/fetching.py

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -201,19 +201,15 @@ def _no_op_batch_to_device(batch: Any) -> Any:
201201

202202

203203
class DataFetcher(AbstractDataFetcher):
204-
205-
"""This class is used to control batch fetching flow. By default, the ``fetching_function`` will pre-fetch a
206-
batch in advance to detect the end of the iteration.
204+
"""This class is used to control batch fetching flow.
207205
208206
Args:
209-
prefetch_batches: Number of batches to be pre-fetched. Lightning will pre-fetch
210-
at least 1 batch for tracking the latest batch.
207+
prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
208+
whether a batch is the last one (available with :attr:`self.done`).
211209
store_on_device: Whether to store the pre-fetched batches on device.
212210
"""
213211

214212
def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None:
215-
if prefetch_batches < 1:
216-
raise MisconfigurationException("`prefetch_batches` should at least be 1.")
217213
super().__init__(prefetch_batches=prefetch_batches)
218214
self.store_on_device = store_on_device
219215
self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device
@@ -240,19 +236,31 @@ def prefetching(self) -> None:
240236
break
241237

242238
def fetching_function(self) -> Tuple[Any, bool]:
239+
assert self.dataloader_iter is not None
243240
if self.batches:
241+
# there are pre-fetched batches already from a previous `prefetching` call.
242+
# consume one
244243
batch = self.batches.pop(0)
245-
else:
246-
# empty iterator, no prefetching done
247-
raise StopIteration
248-
if not self.done:
249-
assert self.dataloader_iter is not None
250244
try:
245+
# refill the consumed batch
251246
self._fetch_next_batch(self.dataloader_iter)
252247
except StopIteration:
248+
# no more batches to fetch. we are done only if all pre-fetched batches were returned
249+
self.done = not self.batches
250+
elif not self.done:
251+
# this will run only when no pre-fetching was done.
252+
try:
253+
self._fetch_next_batch(self.dataloader_iter)
254+
# consume the batch we just fetched
255+
batch = self.batches.pop(0)
256+
except StopIteration as e:
253257
self.done = True
258+
raise e
259+
else:
260+
# the iterator is empty
261+
raise StopIteration
254262
self.wait()
255-
return self.move_to_device(batch), len(self.batches) == 0
263+
return self.move_to_device(batch), self.done
256264

257265
def _fetch_next_batch(self, iterator: Iterator) -> None:
258266
start_output = self.on_fetch_start()

tests/utilities/test_fetching.py

Lines changed: 19 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -40,28 +40,32 @@ def __iter__(self):
4040
yield 2
4141
yield 3
4242

43-
for prefetch_batches in range(1, 5):
43+
for prefetch_batches in range(5):
44+
iterator = DataFetcher(prefetch_batches=prefetch_batches)
45+
assert iterator.prefetch_batches == prefetch_batches
46+
4447
if use_combined_loader:
4548
loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())])
46-
expected = [
47-
([tensor([1]), tensor([1])], False),
48-
([tensor([2]), tensor([2])], False),
49-
([tensor([3]), tensor([3])], True),
50-
]
5149
else:
5250
loader = DataLoader(IterDataset())
53-
expected = [(1, False), (2, False), (3, True)]
54-
iterator = DataFetcher(prefetch_batches=prefetch_batches)
55-
assert iterator.prefetch_batches == prefetch_batches
5651
iterator.setup(loader)
5752

5853
def generate():
59-
generated = []
60-
for idx, data in enumerate(iterator, prefetch_batches + 1):
61-
assert iterator.fetched == 3 if iterator.done else idx
62-
generated.append(data)
54+
generated = [(iterator.fetched, *data) for i, data in enumerate(iterator, prefetch_batches + 1)]
55+
assert iterator.fetched == 3
56+
assert iterator.done
6357
return generated
6458

59+
is_last_batch = [False, False, prefetch_batches > 0]
60+
fetched = list(range(prefetch_batches + 1, 4))
61+
fetched += [3] * (3 - len(fetched))
62+
if use_combined_loader:
63+
batches = [[tensor(1), tensor(1)], [tensor(2), tensor(2)], [tensor(3), tensor(3)]]
64+
else:
65+
batches = [1, 2, 3]
66+
expected = list(zip(fetched, batches, is_last_batch))
67+
assert len(expected) == 3
68+
6569
assert generate() == expected
6670
# validate reset works properly.
6771
assert generate() == expected
@@ -71,9 +75,9 @@ class EmptyIterDataset(IterableDataset):
7175
def __iter__(self):
7276
return iter([])
7377

74-
dataloader = DataLoader(EmptyIterDataset())
78+
loader = DataLoader(EmptyIterDataset())
7579
iterator = DataFetcher()
76-
iterator.setup(dataloader)
80+
iterator.setup(loader)
7781
assert not list(iterator)
7882

7983

0 commit comments

Comments
 (0)