
Disable eval dataloaders replacement during overfitting #10877


Status: Merged (9 commits, Dec 3, 2021)
4 changes: 2 additions & 2 deletions docs/source/common/debugging.rst
@@ -79,10 +79,10 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)

.. testcode::

# use only 1% of training data (and use the same training dataloader (with shuffle off) in val and test)
# use only 1% of training data (and turn off validation)
trainer = Trainer(overfit_batches=0.01)

# similar, but with a fixed 10 batches no matter the size of the dataset
# similar, but with a fixed 10 batches
trainer = Trainer(overfit_batches=10)

With this flag, the train, val, and test sets will all be the same train set. We will also replace the sampler
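As an illustrative aside (not part of this diff), the behaviour described above makes ``overfit_batches`` roughly equivalent to the following manual configuration, with Lightning additionally forcing shuffle off on the training dataloader:

    from pytorch_lightning import Trainer

    trainer = Trainer(overfit_batches=0.01)
    # roughly equivalent to:
    trainer = Trainer(limit_train_batches=0.01, limit_val_batches=0)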
4 changes: 2 additions & 2 deletions docs/source/common/trainer.rst
@@ -1074,7 +1074,7 @@ overfit_batches

|

Uses this much data of the training set. If nonzero, will use the same training set for validation and testing.
Uses this much data of the training set. If nonzero, will turn off validation.
If the training dataloaders have `shuffle=True`, Lightning will automatically disable it.

Useful for quickly debugging or trying to overfit on purpose.
@@ -1084,7 +1084,7 @@ Useful for quickly debugging or trying to overfit on purpose.
# default used by the Trainer
trainer = Trainer(overfit_batches=0.0)

# use only 1% of the train set (and use the train set for val and test)
# use only 1% of the train set
trainer = Trainer(overfit_batches=0.01)

# overfit on 10 of the same batches
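The sampler replacement mentioned above ("If the training dataloaders have `shuffle=True`, Lightning will automatically disable it") boils down to rebuilding the loader with a SequentialSampler. A minimal plain-PyTorch sketch, independent of Lightning internals:

    import torch
    from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

    dataset = TensorDataset(torch.randn(64, 2))
    shuffled = DataLoader(dataset, batch_size=8, shuffle=True)

    # rebuild the loader with a SequentialSampler so every epoch sees the same order
    deterministic = DataLoader(dataset, batch_size=shuffled.batch_size, sampler=SequentialSampler(dataset))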
2 changes: 1 addition & 1 deletion docs/source/guides/speed.rst
@@ -336,7 +336,7 @@ If you don't want to check 100% of the training/validation/test set set these fl

If you also pass ``shuffle=True`` to the dataloader, a different random subset of your dataset will be used for each epoch; otherwise the same subset will be used for all epochs.

.. note:: ``limit_train_batches``, ``limit_val_batches`` and ``limit_test_batches`` will be overwritten by ``overfit_batches`` if ``overfit_batches`` > 0. ``limit_val_batches`` will be ignored if ``fast_dev_run=True``.
.. note:: ``limit_train_batches`` will be overwritten by ``overfit_batches`` if ``overfit_batches > 0`` and will turn off validation.

.. note:: If you set ``limit_val_batches=0``, validation will be disabled.

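A small sketch of the precedence described in the first note (illustrative, not taken from the docs change):

    from pytorch_lightning import Trainer

    # limit_train_batches is overwritten because overfit_batches > 0,
    # and validation is turned off entirely
    trainer = Trainer(limit_train_batches=0.5, overfit_batches=10)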
8 changes: 8 additions & 0 deletions legacy/simple_classif_training.py
@@ -55,6 +55,9 @@ def _split_data(self):
self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(
self._x, self._y, test_size=0.20, random_state=42
)
self.x_train, self.x_predict, self.y_train, self.y_predict = train_test_split(
self._x, self._y, test_size=0.20, random_state=42
)
self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(
self.x_train, self.y_train, test_size=0.40, random_state=42
)
@@ -76,6 +79,11 @@ def test_dataloader(self):
SklearnDataset(self.x_test, self.y_test, self._x_type, self._y_type), batch_size=self.batch_size
)

def predict_dataloader(self):
return DataLoader(
SklearnDataset(self.x_predict, self.y_predict, self._x_type, self._y_type), batch_size=self.batch_size
)


class ClassifDataModule(SklearnDataModule):
def __init__(self, num_features=24, length=6000, num_classes=3, batch_size=128):
24 changes: 5 additions & 19 deletions pytorch_lightning/trainer/data_loading.py
@@ -14,7 +14,6 @@
import multiprocessing
import os
from abc import ABC
from copy import deepcopy
from typing import Any, Callable, Collection, List, Optional, Tuple, Union

from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
@@ -293,28 +292,15 @@ def _reset_eval_dataloader(
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]

# when overfitting, use the training loader as val and test
# duplicate it the numb of times needed to match the train loaders
if self.overfit_batches > 0:
train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)
dataloaders = [deepcopy(train_dataloader) for _ in range(len(dataloaders))]

for loader_i in range(len(dataloaders)):
loader = dataloaders[loader_i]

if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler):
# when overfitting, the dataloader should not have sampler
if self.overfit_batches > 0 and mode.evaluating:
rank_zero_warn(
"You requested to overfit but enabled val/test dataloader shuffling."
" We are turning it off for you."
)
dataloaders[loader_i] = _update_dataloader(loader, SequentialSampler(loader.dataset), mode=mode)
else:
rank_zero_warn(
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
"it is strongly recommended that you turn this off for val/test/predict dataloaders."
)
rank_zero_warn(
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
" it is strongly recommended that you turn this off for val/test/predict dataloaders.",
category=PossibleUserWarning,
)

if any(dl is None for dl in dataloaders):
rank_zero_warn("One of given dataloaders is None and it will be skipped.")
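A rough standalone sketch of the simplified flow after this change (not the actual Lightning code): eval dataloaders are no longer replaced with a copy of the train loader when overfitting, and a shuffled eval loader now only triggers a warning:

    import warnings
    from torch.utils.data import DataLoader, SequentialSampler

    def warn_if_shuffled(loader: DataLoader, prefix: str) -> None:
        # hypothetical helper mirroring the warning path shown in the diff above
        if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler):
            warnings.warn(
                f"Your `{prefix}_dataloader` has `shuffle=True`, it is strongly"
                " recommended that you turn this off for val/test/predict dataloaders."
            )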
68 changes: 68 additions & 0 deletions tests/trainer/flags/test_limit_batches.py
@@ -0,0 +1,68 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage
from tests.helpers.boring_model import BoringModel


def test_num_dataloader_batches(tmpdir):
"""Tests that the correct number of batches are allocated."""
# when we have fewer batches in the dataloader we should use those instead of the limit
model = BoringModel()
trainer = Trainer(limit_val_batches=100, limit_train_batches=100, max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model)

assert len(model.train_dataloader()) == 64
assert len(model.val_dataloader()) == 64
assert isinstance(trainer.num_val_batches, list)
assert trainer.num_val_batches[0] == 64
assert trainer.num_training_batches == 64

# when we have more batches in the dataloader we should limit them
model = BoringModel()
trainer = Trainer(limit_val_batches=7, limit_train_batches=7, max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model)

assert len(model.train_dataloader()) == 64
assert len(model.val_dataloader()) == 64
assert isinstance(trainer.num_val_batches, list)
assert trainer.num_val_batches[0] == 7
assert trainer.num_training_batches == 7


@pytest.mark.parametrize(
["stage", "mode"],
[
(RunningStage.VALIDATING, "val"),
(RunningStage.TESTING, "test"),
(RunningStage.PREDICTING, "predict"),
],
)
@pytest.mark.parametrize("limit_batches", [0.1, 10])
def test_eval_limit_batches(stage, mode, limit_batches):
limit_eval_batches = f"limit_{mode}_batches"
dl_hook = f"{mode}_dataloader"
model = BoringModel()
eval_loader = getattr(model, dl_hook)()

trainer = Trainer(**{limit_eval_batches: limit_batches})
model.trainer = trainer
trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model)
expected_batches = int(limit_batches * len(eval_loader)) if isinstance(limit_batches, float) else limit_batches
assert loader_num_batches[0] == expected_batches
assert len(dataloaders[0]) == len(eval_loader)
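The expected-batch arithmetic used in the parametrised test, worked through under the assumption of the 64-batch BoringModel loader asserted earlier:

    limit_batches = 0.1
    num_loader_batches = 64
    expected = int(limit_batches * num_loader_batches)
    assert expected == 6  # 0.1 * 64 = 6.4, truncated to 6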
77 changes: 76 additions & 1 deletion tests/trainer/flags/test_overfit_batches.py
@@ -13,10 +13,13 @@
# limitations under the License.
import pytest
import torch
from torch.utils.data.sampler import Sampler, SequentialSampler
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler

from legacy.simple_classif_training import ClassifDataModule, ClassificationModel
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf


@pytest.mark.parametrize("overfit_batches", [1, 2, 0.1, 0.25, 1.0])
@@ -62,3 +65,75 @@ def train_dataloader(self):
trainer.fit(model)

assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)


@pytest.mark.parametrize(
"stage,mode",
[(RunningStage.VALIDATING, "val"), (RunningStage.TESTING, "test"), (RunningStage.PREDICTING, "predict")],
)
@pytest.mark.parametrize("overfit_batches", [0.11, 4])
def test_overfit_batch_limits_eval(stage, mode, overfit_batches):
model = ClassificationModel()
dm = ClassifDataModule()
eval_loader = getattr(dm, f"{mode}_dataloader")()
trainer = Trainer(overfit_batches=overfit_batches)
model.trainer = trainer
trainer._data_connector.attach_datamodule(model, datamodule=dm)

loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model)
if stage == RunningStage.VALIDATING:
assert loader_num_batches[0] == 0
else:
assert loader_num_batches[0] == len(eval_loader)
assert isinstance(dataloaders[0].sampler, SequentialSampler)


@pytest.mark.parametrize("overfit_batches", [0.11, 4])
def test_overfit_batch_limits_train(overfit_batches):
model = ClassificationModel()
dm = ClassifDataModule()

# original train loader which should be replaced in all methods
train_loader = dm.train_dataloader()
assert isinstance(train_loader.sampler, RandomSampler)

# Create a reference train dataloader without shuffling.
train_loader = DataLoader(dm.train_dataloader().dataset, shuffle=False)
(xa, ya) = next(iter(train_loader))
train_loader = DataLoader(dm.train_dataloader().dataset, shuffle=True)
full_train_samples = len(train_loader)

# set the model loaders
model.train_dataloader = lambda: train_loader

# test train loader applies correct limits
trainer = Trainer(overfit_batches=overfit_batches)
model.trainer = trainer
trainer._data_connector.attach_dataloaders(model=model)
trainer.reset_train_dataloader(model)
expected_batches = (
int(overfit_batches * full_train_samples) if isinstance(overfit_batches, float) else overfit_batches
)
assert trainer.num_training_batches == expected_batches

# make sure the loaders are the same
(xb, yb) = next(iter(trainer.train_dataloader))
assert torch.eq(xa, xb).all()
assert torch.eq(ya, yb).all()


@RunIf(skip_windows=True)
def test_distributed_sampler_with_overfit_batches():
model = BoringModel()
trainer = Trainer(
overfit_batches=1,
strategy="ddp_spawn",
num_processes=2,
)
model.trainer = trainer
trainer.model = model
trainer._data_connector.attach_dataloaders(model)
trainer.reset_train_dataloader()
train_sampler = trainer.train_dataloader.loaders.sampler
assert isinstance(train_sampler, DistributedSampler)
assert train_sampler.shuffle is False
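A plain-PyTorch sketch of what the final assertions check (illustrative only; num_replicas and rank are supplied by hand so no process group is required):

    import torch
    from torch.utils.data import DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.arange(32).float())
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
    assert sampler.shuffle is False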
3 changes: 1 addition & 2 deletions tests/trainer/test_data_loading.py
@@ -16,8 +16,7 @@
from re import escape

import pytest
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data.sampler import BatchSampler, Sampler, SequentialSampler
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.data import _update_dataloader