
Commit b7331d8

Disable eval dataloaders replacement during overfitting (#10877)

Authored by rohitgr7 and awaelchli
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent ff20af4 commit b7331d8

File tree

9 files changed: +163 -201 lines changed


docs/source/common/debugging.rst

Lines changed: 2 additions & 2 deletions

@@ -79,10 +79,10 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
 
 .. testcode::
 
-    # use only 1% of training data (and use the same training dataloader (with shuffle off) in val and test)
+    # use only 1% of training data (and turn off validation)
     trainer = Trainer(overfit_batches=0.01)
 
-    # similar, but with a fixed 10 batches no matter the size of the dataset
+    # similar, but with a fixed 10 batches
     trainer = Trainer(overfit_batches=10)
 
 With this flag, the train, val, and test sets will all be the same train set. We will also replace the sampler
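
To illustrate the new behavior described by the updated comment, a minimal sketch (assuming the `BoringModel` test helper imported elsewhere in this commit); the final assertion mirrors the expectation of the new `test_overfit_batch_limits_eval` test, which checks for zero validation batches while overfitting:

    from pytorch_lightning import Trainer
    from tests.helpers.boring_model import BoringModel

    model = BoringModel()
    # overfit on 1% of the training data; validation is now skipped entirely
    # instead of being run on a copy of the training loader
    trainer = Trainer(overfit_batches=0.01, max_epochs=1)
    trainer.fit(model)
    assert trainer.num_val_batches[0] == 0  # mirrors the new test's expectation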

docs/source/common/trainer.rst

Lines changed: 2 additions & 2 deletions

@@ -1074,7 +1074,7 @@ overfit_batches
 
 |
 
-Uses this much data of the training set. If nonzero, will use the same training set for validation and testing.
+Uses this much data of the training set. If nonzero, will turn off validation.
 If the training dataloaders have `shuffle=True`, Lightning will automatically disable it.
 
 Useful for quickly debugging or trying to overfit on purpose.
@@ -1084,7 +1084,7 @@ Useful for quickly debugging or trying to overfit on purpose.
     # default used by the Trainer
     trainer = Trainer(overfit_batches=0.0)
 
-    # use only 1% of the train set (and use the train set for val and test)
+    # use only 1% of the train set
     trainer = Trainer(overfit_batches=0.01)
 
     # overfit on 10 of the same batches

docs/source/guides/speed.rst

Lines changed: 1 addition & 1 deletion

@@ -336,7 +336,7 @@ If you don't want to check 100% of the training/validation/test set, set these flags
 
 If you also pass ``shuffle=True`` to the dataloader, a different random subset of your dataset will be used for each epoch; otherwise the same subset will be used for all epochs.
 
-.. note:: ``limit_train_batches``, ``limit_val_batches`` and ``limit_test_batches`` will be overwritten by ``overfit_batches`` if ``overfit_batches`` > 0. ``limit_val_batches`` will be ignored if ``fast_dev_run=True``.
+.. note:: ``limit_train_batches`` will be overwritten by ``overfit_batches`` if ``overfit_batches > 0``, and validation will be turned off.
 
 .. note:: If you set ``limit_val_batches=0``, validation will be disabled.
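
As a worked sketch of how these limits resolve to batch counts (mirroring the `int(limit * len(loader))` convention and the capping behavior asserted in the new tests in this commit; the loader length of 64 is illustrative):

    # fractional limits: truncate limit * len(dataloader)
    full_batches = 64
    num_val_batches = int(0.1 * full_batches)  # -> 6 batches

    # integer limits are used as-is, capped by the loader length
    num_val_batches = min(7, full_batches)  # -> 7 batches
    num_val_batches = min(100, full_batches)  # -> 64 batches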

legacy/simple_classif_training.py

Lines changed: 8 additions & 0 deletions

@@ -55,6 +55,9 @@ def _split_data(self):
         self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(
             self._x, self._y, test_size=0.20, random_state=42
         )
+        self.x_train, self.x_predict, self.y_train, self.y_predict = train_test_split(
+            self._x, self._y, test_size=0.20, random_state=42
+        )
         self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(
             self.x_train, self.y_train, test_size=0.40, random_state=42
         )
@@ -76,6 +79,11 @@ def test_dataloader(self):
             SklearnDataset(self.x_test, self.y_test, self._x_type, self._y_type), batch_size=self.batch_size
         )
 
+    def predict_dataloader(self):
+        return DataLoader(
+            SklearnDataset(self.x_predict, self.y_predict, self._x_type, self._y_type), batch_size=self.batch_size
+        )
+
 
 class ClassifDataModule(SklearnDataModule):
     def __init__(self, num_features=24, length=6000, num_classes=3, batch_size=128):
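
A hedged usage sketch for the new predict split added above; it assumes `_split_data` has run, whether via the datamodule's own setup path or, as here, called explicitly:

    dm = ClassifDataModule()
    dm._split_data()  # now also produces x_predict / y_predict (20% of the data)
    predict_loader = dm.predict_dataloader()
    x, y = next(iter(predict_loader))  # one batch from the predict split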

pytorch_lightning/trainer/data_loading.py

Lines changed: 5 additions & 19 deletions

@@ -14,7 +14,6 @@
 import multiprocessing
 import os
 from abc import ABC
-from copy import deepcopy
 from typing import Any, Callable, Collection, List, Optional, Tuple, Union
 
 from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
@@ -298,28 +297,15 @@ def _reset_eval_dataloader(
         if not isinstance(dataloaders, list):
             dataloaders = [dataloaders]
 
-        # when overfitting, use the training loader as val and test
-        # duplicate it the numb of times needed to match the train loaders
-        if self.overfit_batches > 0:
-            train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)
-            dataloaders = [deepcopy(train_dataloader) for _ in range(len(dataloaders))]
-
         for loader_i in range(len(dataloaders)):
             loader = dataloaders[loader_i]
 
             if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler):
-                # when overfitting, the dataloader should not have sampler
-                if self.overfit_batches > 0 and mode.evaluating:
-                    rank_zero_warn(
-                        "You requested to overfit but enabled val/test dataloader shuffling."
-                        " We are turning it off for you."
-                    )
-                    dataloaders[loader_i] = _update_dataloader(loader, SequentialSampler(loader.dataset), mode=mode)
-                else:
-                    rank_zero_warn(
-                        f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
-                        "it is strongly recommended that you turn this off for val/test/predict dataloaders."
-                    )
+                rank_zero_warn(
+                    f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
+                    " it is strongly recommended that you turn this off for val/test/predict dataloaders.",
+                    category=PossibleUserWarning,
+                )
 
         if any(dl is None for dl in dataloaders):
             rank_zero_warn("One of given dataloaders is None and it will be skipped.")
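
A short sketch of the retained warning path (assuming the `BoringModel`/`RandomDataset` helpers imported in the tests below): an eval loader built with `shuffle=True` now only triggers the `PossibleUserWarning` shown above; it is no longer silently rebuilt with a `SequentialSampler` when `overfit_batches > 0`:

    from torch.utils.data import DataLoader
    from tests.helpers.boring_model import BoringModel, RandomDataset

    class ShuffledEvalModel(BoringModel):
        def val_dataloader(self):
            # emits "Your `val_dataloader` has `shuffle=True` ..." as a
            # PossibleUserWarning, but the sampler is left untouched
            return DataLoader(RandomDataset(32, 64), shuffle=True)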
Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.states import RunningStage
+from tests.helpers.boring_model import BoringModel
+
+
+def test_num_dataloader_batches(tmpdir):
+    """Tests that the correct number of batches are allocated."""
+    # when we have fewer batches in the dataloader we should use those instead of the limit
+    model = BoringModel()
+    trainer = Trainer(limit_val_batches=100, limit_train_batches=100, max_epochs=1, default_root_dir=tmpdir)
+    trainer.fit(model)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert isinstance(trainer.num_val_batches, list)
+    assert trainer.num_val_batches[0] == 64
+    assert trainer.num_training_batches == 64
+
+    # when we have more batches in the dataloader we should limit them
+    model = BoringModel()
+    trainer = Trainer(limit_val_batches=7, limit_train_batches=7, max_epochs=1, default_root_dir=tmpdir)
+    trainer.fit(model)
+
+    assert len(model.train_dataloader()) == 64
+    assert len(model.val_dataloader()) == 64
+    assert isinstance(trainer.num_val_batches, list)
+    assert trainer.num_val_batches[0] == 7
+    assert trainer.num_training_batches == 7
+
+
+@pytest.mark.parametrize(
+    ["stage", "mode"],
+    [
+        (RunningStage.VALIDATING, "val"),
+        (RunningStage.TESTING, "test"),
+        (RunningStage.PREDICTING, "predict"),
+    ],
+)
+@pytest.mark.parametrize("limit_batches", [0.1, 10])
+def test_eval_limit_batches(stage, mode, limit_batches):
+    limit_eval_batches = f"limit_{mode}_batches"
+    dl_hook = f"{mode}_dataloader"
+    model = BoringModel()
+    eval_loader = getattr(model, dl_hook)()
+
+    trainer = Trainer(**{limit_eval_batches: limit_batches})
+    model.trainer = trainer
+    trainer._data_connector.attach_dataloaders(model)
+    loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model)
+    expected_batches = int(limit_batches * len(eval_loader)) if isinstance(limit_batches, float) else limit_batches
+    assert loader_num_batches[0] == expected_batches
+    assert len(dataloaders[0]) == len(eval_loader)
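
Outside the test harness, the same limits can be exercised through a plain fit; a hedged sketch (BoringModel's default loaders hold 64 batches, per the assertions above):

    from pytorch_lightning import Trainer
    from tests.helpers.boring_model import BoringModel

    model = BoringModel()
    trainer = Trainer(limit_train_batches=7, limit_val_batches=7, max_epochs=1)
    trainer.fit(model)
    assert trainer.num_training_batches == 7
    assert trainer.num_val_batches == [7]  # one entry per val dataloader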

tests/trainer/flags/test_overfit_batches.py

Lines changed: 76 additions & 1 deletion

@@ -13,10 +13,13 @@
 # limitations under the License.
 import pytest
 import torch
-from torch.utils.data.sampler import Sampler, SequentialSampler
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler
 
+from legacy.simple_classif_training import ClassifDataModule, ClassificationModel
 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.states import RunningStage
 from tests.helpers.boring_model import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
 
 
 @pytest.mark.parametrize("overfit_batches", [1, 2, 0.1, 0.25, 1.0])
@@ -62,3 +65,75 @@ def train_dataloader(self):
     trainer.fit(model)
 
     assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)
+
+
+@pytest.mark.parametrize(
+    "stage,mode",
+    [(RunningStage.VALIDATING, "val"), (RunningStage.TESTING, "test"), (RunningStage.PREDICTING, "predict")],
+)
+@pytest.mark.parametrize("overfit_batches", [0.11, 4])
+def test_overfit_batch_limits_eval(stage, mode, overfit_batches):
+    model = ClassificationModel()
+    dm = ClassifDataModule()
+    eval_loader = getattr(dm, f"{mode}_dataloader")()
+    trainer = Trainer(overfit_batches=overfit_batches)
+    model.trainer = trainer
+    trainer._data_connector.attach_datamodule(model, datamodule=dm)
+
+    loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model)
+    if stage == RunningStage.VALIDATING:
+        assert loader_num_batches[0] == 0
+    else:
+        assert loader_num_batches[0] == len(eval_loader)
+        assert isinstance(dataloaders[0].sampler, SequentialSampler)
+
+
+@pytest.mark.parametrize("overfit_batches", [0.11, 4])
+def test_overfit_batch_limits_train(overfit_batches):
+    model = ClassificationModel()
+    dm = ClassifDataModule()
+
+    # original train loader which should be replaced in all methods
+    train_loader = dm.train_dataloader()
+    assert isinstance(train_loader.sampler, RandomSampler)
+
+    # Create a reference train dataloader without shuffling.
+    train_loader = DataLoader(dm.train_dataloader().dataset, shuffle=False)
+    (xa, ya) = next(iter(train_loader))
+    train_loader = DataLoader(dm.train_dataloader().dataset, shuffle=True)
+    full_train_samples = len(train_loader)
+
+    # set the model loaders
+    model.train_dataloader = lambda: train_loader
+
+    # test train loader applies correct limits
+    trainer = Trainer(overfit_batches=overfit_batches)
+    model.trainer = trainer
+    trainer._data_connector.attach_dataloaders(model=model)
+    trainer.reset_train_dataloader(model)
+    expected_batches = (
+        int(overfit_batches * full_train_samples) if isinstance(overfit_batches, float) else overfit_batches
+    )
+    assert trainer.num_training_batches == expected_batches
+
+    # make sure the loaders are the same
+    (xb, yb) = next(iter(trainer.train_dataloader))
+    assert torch.eq(xa, xb).all()
+    assert torch.eq(ya, yb).all()
+
+
+@RunIf(skip_windows=True)
+def test_distributed_sampler_with_overfit_batches():
+    model = BoringModel()
+    trainer = Trainer(
+        overfit_batches=1,
+        strategy="ddp_spawn",
+        num_processes=2,
+    )
+    model.trainer = trainer
+    trainer.model = model
+    trainer._data_connector.attach_dataloaders(model)
+    trainer.reset_train_dataloader()
+    train_sampler = trainer.train_dataloader.loaders.sampler
+    assert isinstance(train_sampler, DistributedSampler)
+    assert train_sampler.shuffle is False
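
Mirroring the sampler assertions above, a compact single-process sketch of the guarantee that overfitting pins the batch order (assumes `BoringModel`, as in the existing tests):

    from torch.utils.data.sampler import SequentialSampler
    from pytorch_lightning import Trainer
    from tests.helpers.boring_model import BoringModel

    model = BoringModel()
    trainer = Trainer(overfit_batches=2, max_epochs=1)
    trainer.fit(model)
    # shuffling was disabled, so every epoch sees the same two batches
    assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)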

tests/trainer/test_data_loading.py

Lines changed: 1 addition & 2 deletions

@@ -16,8 +16,7 @@
 from re import escape
 
 import pytest
-from torch.utils.data import DataLoader, DistributedSampler
-from torch.utils.data.sampler import BatchSampler, Sampler, SequentialSampler
+from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.utilities.data import _update_dataloader
