Skip to content

Commit 4f48372

Browse files
committed
disable eval dataloader replacement
1 parent c647841 commit 4f48372

File tree

4 files changed

+148
-194
lines changed

4 files changed

+148
-194
lines changed

pytorch_lightning/trainer/data_loading.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import multiprocessing
1515
import os
1616
from abc import ABC
17-
from copy import deepcopy
1817
from typing import Any, Callable, Collection, List, Optional, Tuple, Union
1918

2019
from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
@@ -293,28 +292,14 @@ def _reset_eval_dataloader(
293292
if not isinstance(dataloaders, list):
294293
dataloaders = [dataloaders]
295294

296-
# when overfitting, use the training loader as val and test
297-
# duplicate it the number of times needed to match the train loaders
298-
if self.overfit_batches > 0:
299-
train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)
300-
dataloaders = [deepcopy(train_dataloader) for _ in range(len(dataloaders))]
301-
302295
for loader_i in range(len(dataloaders)):
303296
loader = dataloaders[loader_i]
304297

305298
if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler):
306-
# when overfitting, the dataloader should not have sampler
307-
if self.overfit_batches > 0 and mode.evaluating:
308-
rank_zero_warn(
309-
"You requested to overfit but enabled val/test dataloader shuffling."
310-
" We are turning it off for you."
311-
)
312-
dataloaders[loader_i] = _update_dataloader(loader, SequentialSampler(loader.dataset), mode=mode)
313-
else:
314-
rank_zero_warn(
315-
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
316-
"it is strongly recommended that you turn this off for val/test/predict dataloaders."
317-
)
299+
rank_zero_warn(
300+
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
301+
"it is strongly recommended that you turn this off for val/test/predict dataloaders."
302+
)
318303

319304
if any(dl is None for dl in dataloaders):
320305
rank_zero_warn("One of given dataloaders is None and it will be skipped.")

tests/trainer/flags/test_overfit_batches.py

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# limitations under the License.
1414
import pytest
1515
import torch
16-
from torch.utils.data.sampler import Sampler, SequentialSampler
16+
from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
1717

1818
from pytorch_lightning import Trainer
19+
from pytorch_lightning.trainer.states import RunningStage
20+
from tests.base.model_template import EvalModelTemplate
1921
from tests.helpers.boring_model import BoringModel, RandomDataset
2022

2123

@@ -62,3 +64,128 @@ def train_dataloader(self):
6264
trainer.fit(model)
6365

6466
assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)
67+
68+
69+
def test_overfit_batch_limits(tmpdir):
70+
# ------------------------------------------------------
71+
# Make sure shuffle is correct across loaders initially
72+
# ------------------------------------------------------
73+
model = EvalModelTemplate()
74+
model.train_dataloader()
75+
76+
# original train loader which should be replaced in all methods
77+
train_loader = model.train_dataloader()
78+
79+
# make sure the val and test loaders are not shuffled
80+
assert isinstance(train_loader.sampler, RandomSampler)
81+
assert isinstance(model.val_dataloader().sampler, SequentialSampler)
82+
assert isinstance(model.test_dataloader().sampler, SequentialSampler)
83+
84+
# ------------------------------------------------------
85+
# get the training loader and batch
86+
# ------------------------------------------------------
87+
# Create a reference train dataloader without shuffling.
88+
train_loader = DataLoader(model.train_dataloader().dataset, shuffle=False)
89+
(xa, ya) = next(iter(train_loader))
90+
train_loader = DataLoader(model.train_dataloader().dataset, shuffle=True)
91+
full_train_samples = len(train_loader)
92+
num_train_samples = int(0.11 * full_train_samples)
93+
94+
# ------------------------------------------------------
95+
# set VAL and Test loaders
96+
# ------------------------------------------------------
97+
val_loader = DataLoader(model.val_dataloader().dataset, shuffle=False)
98+
test_loader = DataLoader(model.test_dataloader().dataset, shuffle=False)
99+
100+
# set the model loaders
101+
model.train_dataloader = lambda: train_loader
102+
model.val_dataloader = lambda: val_loader
103+
model.test_dataloader = lambda: test_loader
104+
105+
# ------------------------------------------------------
106+
# test train loader applies correct limits
107+
# ------------------------------------------------------
108+
trainer = Trainer(overfit_batches=4)
109+
model.trainer = trainer
110+
trainer._data_connector.attach_dataloaders(model=model)
111+
trainer.reset_train_dataloader(model)
112+
assert trainer.num_training_batches == 4
113+
114+
# make sure the loaders are the same
115+
(xb, yb) = next(iter(trainer.train_dataloader))
116+
assert torch.eq(xa, xb).all()
117+
assert torch.eq(ya, yb).all()
118+
119+
trainer = Trainer(overfit_batches=0.11)
120+
model.trainer = trainer
121+
trainer._data_connector.attach_dataloaders(model=model)
122+
trainer.reset_train_dataloader(model)
123+
# The dataloader should have been overwritten with a Sequential sampler.
124+
assert trainer.train_dataloader is not train_loader
125+
assert trainer.num_training_batches == num_train_samples
126+
127+
# make sure the loaders are the same
128+
(xb, yb) = next(iter(trainer.train_dataloader))
129+
assert torch.eq(xa, xb).all()
130+
assert torch.eq(ya, yb).all()
131+
132+
# ------------------------------------------------------
133+
# run tests for both val and test
134+
# ------------------------------------------------------
135+
for split in (RunningStage.VALIDATING, RunningStage.TESTING):
136+
137+
# ------------------------------------------------------
138+
# test overfit_batches as percent
139+
# ------------------------------------------------------
140+
trainer = Trainer(overfit_batches=0.11)
141+
trainer._data_connector.attach_dataloaders(model)
142+
loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
143+
if split == RunningStage.VALIDATING:
144+
assert loader_num_batches[0] == 0
145+
else:
146+
assert loader_num_batches[0] == len(test_loader)
147+
148+
# ------------------------------------------------------
149+
# test overfit_batches as int
150+
# ------------------------------------------------------
151+
trainer = Trainer(overfit_batches=1)
152+
trainer._data_connector.attach_dataloaders(model)
153+
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
154+
if split == RunningStage.VALIDATING:
155+
assert loader_num_batches[0] == 0
156+
else:
157+
assert loader_num_batches[0] == len(test_loader)
158+
# make sure we turned off shuffle for the user
159+
assert isinstance(dataloaders[0].sampler, SequentialSampler)
160+
161+
trainer = Trainer(overfit_batches=5)
162+
trainer._data_connector.attach_dataloaders(model)
163+
loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
164+
if split == RunningStage.VALIDATING:
165+
assert loader_num_batches[0] == 0
166+
else:
167+
assert loader_num_batches[0] == len(test_loader)
168+
169+
# ------------------------------------------------------
170+
# test limit_xxx_batches as percent AND int
171+
# ------------------------------------------------------
172+
if split == RunningStage.VALIDATING:
173+
trainer = Trainer(limit_val_batches=0.1)
174+
trainer._data_connector.attach_dataloaders(model)
175+
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
176+
assert loader_num_batches[0] == int(0.1 * len(val_loader))
177+
178+
trainer = Trainer(limit_val_batches=10)
179+
trainer._data_connector.attach_dataloaders(model)
180+
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
181+
assert loader_num_batches[0] == 10
182+
else:
183+
trainer = Trainer(limit_test_batches=0.1)
184+
trainer._data_connector.attach_dataloaders(model)
185+
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
186+
assert loader_num_batches[0] == int(0.1 * len(test_loader))
187+
188+
trainer = Trainer(limit_test_batches=10)
189+
trainer._data_connector.attach_dataloaders(model)
190+
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
191+
assert loader_num_batches[0] == 10

tests/trainer/test_data_loading.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,19 @@ def test_pre_made_batches():
335335
loader = DataLoader(RandomDataset(32, 10), batch_size=None)
336336
trainer = Trainer(fast_dev_run=1)
337337
trainer.predict(LoaderTestModel(), loader)
338+
339+
340+
@RunIf(skip_windows=True)
341+
def test_distributed_sampler_with_overfit_batches(tmpdir):
342+
model = BoringModel()
343+
trainer = Trainer(
344+
default_root_dir=tmpdir,
345+
overfit_batches=1,
346+
fast_dev_run=1,
347+
strategy="ddp_find_unused_parameters_false",
348+
num_processes=1,
349+
)
350+
trainer.fit(model)
351+
train_sampler = trainer.train_dataloader.loaders.sampler
352+
assert isinstance(train_sampler, DistributedSampler)
353+
assert train_sampler.shuffle is False

tests/trainer/test_trainer_tricks.py

Lines changed: 0 additions & 174 deletions
This file was deleted.

0 commit comments

Comments
 (0)