
Disable eval dataloaders replacement during overfitting #10877


Merged: 9 commits merged on Dec 3, 2021. The diff below shows changes from 5 of the commits.
4 changes: 2 additions & 2 deletions docs/source/common/debugging.rst
@@ -79,10 +79,10 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)

.. testcode::

# use only 1% of training data (and use the same training dataloader (with shuffle off) in val and test)
# use only 1% of training data (and turn off validation)
trainer = Trainer(overfit_batches=0.01)

# similar, but with a fixed 10 batches no matter the size of the dataset
# similar, but with a fixed 10 batches
trainer = Trainer(overfit_batches=10)

With this flag, the train, val, and test sets will all be the same train set. We will also replace the sampler
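For illustration, a minimal sketch of the updated semantics, assuming Lightning's ``BoringModel`` test helper and default Trainer settings:

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel

model = BoringModel()
# train on roughly 1% of the training batches; with this change the val/test
# loaders are no longer replaced by copies of the train loader, and validation
# is turned off while overfit_batches > 0
trainer = Trainer(overfit_batches=0.01, max_epochs=1)
trainer.fit(model)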
4 changes: 2 additions & 2 deletions docs/source/common/trainer.rst
@@ -1074,7 +1074,7 @@ overfit_batches

|

Uses this much data of the training set. If nonzero, will use the same training set for validation and testing.
Uses this much data of the training set. If nonzero, will turn off validation.
If the training dataloaders have `shuffle=True`, Lightning will automatically disable it.

Useful for quickly debugging or trying to overfit on purpose.
@@ -1084,7 +1084,7 @@ Useful for quickly debugging or trying to overfit on purpose.
# default used by the Trainer
trainer = Trainer(overfit_batches=0.0)

# use only 1% of the train set (and use the train set for val and test)
# use only 1% of the train set
trainer = Trainer(overfit_batches=0.01)

# overfit on 10 of the same batches
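As a rough sketch of the shuffle handling described above (``LitModel`` is a hypothetical stand-in built on the ``BoringModel`` helpers):

from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel, RandomDataset

class LitModel(BoringModel):
    def train_dataloader(self):
        # shuffle=True is what Lightning disables when overfit_batches > 0
        return DataLoader(RandomDataset(32, 64), batch_size=2, shuffle=True)

trainer = Trainer(overfit_batches=10, max_epochs=1)
trainer.fit(LitModel())
# the train loader now uses a SequentialSampler, so every epoch sees the same 10 batches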
2 changes: 1 addition & 1 deletion docs/source/guides/speed.rst
@@ -336,7 +336,7 @@ If you don't want to check 100% of the training/validation/test set, set these flags

If you also pass ``shuffle=True`` to the dataloader, a different random subset of your dataset will be used for each epoch; otherwise the same subset will be used for all epochs.

.. note:: ``limit_train_batches``, ``limit_val_batches`` and ``limit_test_batches`` will be overwritten by ``overfit_batches`` if ``overfit_batches`` > 0. ``limit_val_batches`` will be ignored if ``fast_dev_run=True``.
.. note:: ``limit_train_batches`` will be overwritten by ``overfit_batches`` if ``overfit_batches > 0`` and will turn off validation.

.. note:: If you set ``limit_val_batches=0``, validation will be disabled.
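A short sketch of how these flags interact (the values are arbitrary):

from pytorch_lightning import Trainer

# cap the training and validation loops independently
trainer = Trainer(limit_train_batches=50, limit_val_batches=20)

# overfit_batches takes precedence over limit_train_batches and turns off validation
trainer = Trainer(limit_train_batches=50, overfit_batches=10)

# limit_val_batches=0 also disables validation entirely
trainer = Trainer(limit_val_batches=0)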

24 changes: 5 additions & 19 deletions pytorch_lightning/trainer/data_loading.py
@@ -14,7 +14,6 @@
import multiprocessing
import os
from abc import ABC
from copy import deepcopy
from typing import Any, Callable, Collection, List, Optional, Tuple, Union

from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
@@ -293,28 +292,15 @@ def _reset_eval_dataloader(
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]

# when overfitting, use the training loader as val and test
# duplicate it the numb of times needed to match the train loaders
if self.overfit_batches > 0:
train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)
dataloaders = [deepcopy(train_dataloader) for _ in range(len(dataloaders))]

for loader_i in range(len(dataloaders)):
loader = dataloaders[loader_i]

if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler):
# when overfitting, the dataloader should not have sampler
if self.overfit_batches > 0 and mode.evaluating:
rank_zero_warn(
"You requested to overfit but enabled val/test dataloader shuffling."
" We are turning it off for you."
)
dataloaders[loader_i] = _update_dataloader(loader, SequentialSampler(loader.dataset), mode=mode)
else:
rank_zero_warn(
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
"it is strongly recommended that you turn this off for val/test/predict dataloaders."
)
rank_zero_warn(
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
" it is strongly recommended that you turn this off for val/test/predict dataloaders.",
category=PossibleUserWarning,
)

if any(dl is None for dl in dataloaders):
rank_zero_warn("One of given dataloaders is None and it will be skipped.")
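In short, eval dataloaders are no longer swapped for copies of the train loader when overfitting; they keep their own data, and a shuffled sampler now only triggers the warning above. A minimal sketch of a loader that would trigger it (the model and dataset are hypothetical, built on the ``BoringModel`` helpers):

from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel, RandomDataset

class ShuffledValModel(BoringModel):
    def val_dataloader(self):
        # shuffle=True on an eval loader raises the PossibleUserWarning emitted above
        return DataLoader(RandomDataset(32, 64), shuffle=True)

trainer = Trainer(limit_val_batches=5)
trainer.validate(ShuffledValModel())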
75 changes: 75 additions & 0 deletions tests/trainer/flags/test_limit_batches.py
@@ -0,0 +1,75 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage
from tests.helpers.boring_model import BoringModel


def test_num_dataloader_batches(tmpdir):
"""Tests that the correct number of batches are allocated."""
# when we have fewer batches in the dataloader we should use those instead of the limit
model = BoringModel()
trainer = Trainer(limit_val_batches=100, limit_train_batches=100, max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model)

assert len(model.train_dataloader()) == 64
assert len(model.val_dataloader()) == 64
assert isinstance(trainer.num_val_batches, list)
assert trainer.num_val_batches[0] == 64
assert trainer.num_training_batches == 64

# when we have more batches in the dataloader we should limit them
model = BoringModel()
trainer = Trainer(limit_val_batches=7, limit_train_batches=7, max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model)

assert len(model.train_dataloader()) == 64
assert len(model.val_dataloader()) == 64
assert isinstance(trainer.num_val_batches, list)
assert trainer.num_val_batches[0] == 7
assert trainer.num_training_batches == 7


@pytest.mark.parametrize(
["stage", "mode"],
[
(RunningStage.VALIDATING, "val"),
(RunningStage.TESTING, "test"),
(RunningStage.PREDICTING, "predict"),
],
)
def test_eval_limit_batches(stage, mode):
limit_eval_batches = f"limit_{mode}_batches"
dl_hook = f"{mode}_dataloader"
model = BoringModel()
eval_loader = getattr(model, dl_hook)()

limit_batches = 0.1
trainer = Trainer(**{limit_eval_batches: limit_batches})
model.trainer = trainer
trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model)
assert loader_num_batches[0] == int(limit_batches * len(eval_loader))
assert len(dataloaders[0]) == len(eval_loader)

limit_batches = 10
trainer = Trainer(**{limit_eval_batches: limit_batches})
model.trainer = trainer
trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model)
assert loader_num_batches[0] == limit_batches
assert len(dataloaders[0]) == len(eval_loader)
105 changes: 104 additions & 1 deletion tests/trainer/flags/test_overfit_batches.py
@@ -13,9 +13,11 @@
# limitations under the License.
import pytest
import torch
from torch.utils.data.sampler import Sampler, SequentialSampler
from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler

from legacy.simple_classif_training import ClassifDataModule, ClassificationModel
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage
from tests.helpers.boring_model import BoringModel, RandomDataset


@@ -62,3 +64,104 @@ def train_dataloader(self):
trainer.fit(model)

assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)


def test_overfit_batch_limits(tmpdir):
# ------------------------------------------------------
# Make sure shuffle is correct across loaders initially
# ------------------------------------------------------
model = ClassificationModel()
dm = ClassifDataModule()

# original train loader which should be replaced in all methods
train_loader = dm.train_dataloader()

# make sure the val and tests are not shuffled
assert isinstance(train_loader.sampler, RandomSampler)
assert isinstance(dm.val_dataloader().sampler, SequentialSampler)
assert isinstance(dm.test_dataloader().sampler, SequentialSampler)

# ------------------------------------------------------
# get the training loader and batch
# ------------------------------------------------------
# Create a reference train dataloader without shuffling.
train_loader = DataLoader(dm.train_dataloader().dataset, shuffle=False)
(xa, ya) = next(iter(train_loader))
train_loader = DataLoader(dm.train_dataloader().dataset, shuffle=True)
full_train_samples = len(train_loader)
num_train_samples = int(0.11 * full_train_samples)

# ------------------------------------------------------
# set VAL and Test loaders
# ------------------------------------------------------
val_loader = DataLoader(dm.val_dataloader().dataset, shuffle=False)
test_loader = DataLoader(dm.test_dataloader().dataset, shuffle=False)

# set the model loaders
model.train_dataloader = lambda: train_loader
model.val_dataloader = lambda: val_loader
model.test_dataloader = lambda: test_loader

# ------------------------------------------------------
# test train loader applies correct limits
# ------------------------------------------------------
trainer = Trainer(overfit_batches=4)
model.trainer = trainer
trainer._data_connector.attach_dataloaders(model=model)
trainer.reset_train_dataloader(model)
assert trainer.num_training_batches == 4

# make sure the loaders are the same
(xb, yb) = next(iter(trainer.train_dataloader))
assert torch.eq(xa, xb).all()
assert torch.eq(ya, yb).all()

trainer = Trainer(overfit_batches=0.11)
model.trainer = trainer
trainer._data_connector.attach_dataloaders(model=model)
trainer.reset_train_dataloader(model)
# The dataloader should have been overwritten with a Sequential sampler.
assert trainer.train_dataloader is not train_loader
assert trainer.num_training_batches == num_train_samples

# make sure the loaders are the same
(xb, yb) = next(iter(trainer.train_dataloader))
assert torch.eq(xa, xb).all()
assert torch.eq(ya, yb).all()

# ------------------------------------------------------
# run tests for both val and test
# ------------------------------------------------------
for split in (RunningStage.VALIDATING, RunningStage.TESTING):

# ------------------------------------------------------
# test overfit_batches as percent
# ------------------------------------------------------
trainer = Trainer(overfit_batches=0.11)
trainer._data_connector.attach_dataloaders(model)
loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
if split == RunningStage.VALIDATING:
assert loader_num_batches[0] == 0
else:
assert loader_num_batches[0] == len(test_loader)

# ------------------------------------------------------
# test overfit_batches as int
# ------------------------------------------------------
trainer = Trainer(overfit_batches=1)
trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
if split == RunningStage.VALIDATING:
assert loader_num_batches[0] == 0
else:
assert loader_num_batches[0] == len(test_loader)
# make sure we turned off shuffle for the user
assert isinstance(dataloaders[0].sampler, SequentialSampler)

trainer = Trainer(overfit_batches=5)
trainer._data_connector.attach_dataloaders(model)
loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
if split == RunningStage.VALIDATING:
assert loader_num_batches[0] == 0
else:
assert loader_num_batches[0] == len(test_loader)
16 changes: 16 additions & 0 deletions tests/trainer/test_data_loading.py
@@ -335,3 +335,19 @@ def test_pre_made_batches():
loader = DataLoader(RandomDataset(32, 10), batch_size=None)
trainer = Trainer(fast_dev_run=1)
trainer.predict(LoaderTestModel(), loader)


@RunIf(skip_windows=True)
def test_distributed_sampler_with_overfit_batches(tmpdir):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
overfit_batches=1,
fast_dev_run=1,
strategy="ddp",
num_processes=2,
)
trainer.fit(model)
train_sampler = trainer.train_dataloader.loaders.sampler
assert isinstance(train_sampler, DistributedSampler)
assert train_sampler.shuffle is False