Skip to content

Disable validation completely when overfit_batches>0 #9709

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
35 commits merged on Dec 1, 2021
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
cf7b1ef
[fix] bug that caused the "no validation" setting to be ineffective
popfido Sep 26, 2021
ef0b9ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2021
0587a1e
Add CHANGELOG
popfido Sep 26, 2021
bf4c110
Update CHANGELOG.md
popfido Oct 12, 2021
99f196f
update condition logic
popfido Oct 15, 2021
346ede5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2021
cfd82a1
change corresponding logic according to #9989
popfido Nov 22, 2021
a02b39b
Merge remote-tracking branch 'origin/master' into bugfix/8962
popfido Nov 22, 2021
c741753
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2021
9e9cb9b
reset val_batches to 0 if `overfit_batches` > 0
popfido Nov 22, 2021
ef8c54a
remove duplicate test case
popfido Nov 23, 2021
943d2ea
update outdated testcases
popfido Nov 23, 2021
01f0901
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2021
3e83de0
remove redundant assert
popfido Nov 23, 2021
190b3f5
deprecate usage of overfit_batches
popfido Nov 23, 2021
96cde83
do not test validation in overfit_batch testcase
popfido Nov 24, 2021
c2dd74b
Modify changelog
popfido Nov 24, 2021
49371e7
Merge remote-tracking branch 'upstream/master' into bugfix/8962
popfido Nov 24, 2021
1c075a4
Update tests/models/test_cpu.py
popfido Nov 24, 2021
15ec209
Update tests/callbacks/test_early_stopping.py
popfido Nov 24, 2021
1d8c3a3
Update tests/callbacks/test_early_stopping.py
popfido Nov 24, 2021
3554f67
Update tests/callbacks/test_early_stopping.py
popfido Nov 24, 2021
d0d032c
Update pytorch_lightning/trainer/trainer.py
popfido Nov 24, 2021
80cc4b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2021
ad167fe
reschedule test cases and fix bugs
popfido Nov 25, 2021
8ae0742
add checker for test
popfido Nov 25, 2021
27231b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 25, 2021
7bbc790
remove unused import
popfido Nov 25, 2021
68539cd
reduce test cases data scale to pass
popfido Nov 25, 2021
28de952
undo test case scale reduction
popfido Nov 25, 2021
2f78e71
remove redundant document
popfido Nov 29, 2021
f1a9692
Merge remote-tracking branch 'upstream/master' into bugfix/8962
popfido Nov 30, 2021
bfb89bc
Merge branch 'master' into bugfix/8962
popfido Nov 30, 2021
89a76cb
Merge branch 'master' into bugfix/8962
popfido Dec 1, 2021
20a8baf
Update CHANGELOG.md
rohitgr7 Dec 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raised an error if the `batch_size` cannot be inferred from the current batch if it contained a string or was a custom batch object ([#10541](https://github.com/PyTorchLightning/pytorch-lightning/pull/10541))


-
- Disable validation completely when `overfit_batches > 0` ([#9709](https://github.com/PyTorchLightning/pytorch-lightning/pull/9709))


-
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,8 +626,7 @@ def _determine_data_use_amount(self, overfit_batches: float) -> None:
"""Use less data for debugging purposes."""
if overfit_batches > 0:
self.limit_train_batches = overfit_batches
self.limit_val_batches = overfit_batches
self.limit_test_batches = overfit_batches
self.limit_val_batches = 0

def _setup_on_init(self, num_sanity_val_steps: int) -> None:
self._log_device_info()
Expand Down
19 changes: 16 additions & 3 deletions tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,13 @@ def validation_epoch_end(self, outputs):
early_stopping = EarlyStopping(
monitor="abc", stopping_threshold=stopping_threshold, divergence_threshold=divergence_theshold
)
trainer = Trainer(default_root_dir=tmpdir, callbacks=[early_stopping], overfit_batches=0.20, max_epochs=20)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stopping],
limit_train_batches=0.2,
limit_val_batches=0.2,
max_epochs=20,
)
trainer.fit(model)
assert trainer.current_epoch == expected_epoch, "early_stopping failed"

Expand All @@ -246,7 +252,13 @@ def validation_epoch_end(self, outputs):

model = CurrentModel()
early_stopping = EarlyStopping(monitor="val_loss", check_finite=True)
trainer = Trainer(default_root_dir=tmpdir, callbacks=[early_stopping], overfit_batches=0.20, max_epochs=10)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stopping],
limit_train_batches=0.2,
limit_val_batches=0.2,
max_epochs=10,
)
trainer.fit(model)
assert trainer.current_epoch == expected_stop_epoch
assert early_stopping.stopped_epoch == expected_stop_epoch
Expand Down Expand Up @@ -426,7 +438,8 @@ def test_multiple_early_stopping_callbacks(
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=callbacks,
overfit_batches=0.20,
limit_train_batches=0.1,
limit_val_batches=0.1,
max_epochs=20,
strategy=strategy,
accelerator="cpu",
Expand Down
1 change: 0 additions & 1 deletion tests/models/test_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def validation_step(self, *args, **kwargs):
callbacks=[stopping],
default_root_dir=tmpdir,
gradient_clip_val=1.0,
overfit_batches=0.20,
track_grad_norm=2,
enable_progress_bar=False,
accumulate_grad_batches=2,
Expand Down
34 changes: 0 additions & 34 deletions tests/trainer/flags/test_overfit_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,40 +19,6 @@
from tests.helpers.boring_model import BoringModel, RandomDataset


def test_overfit_multiple_val_loaders(tmpdir):
"""Tests that overfit batches works with multiple val dataloaders."""
val_dl_count = 2
overfit_batches = 3

class TestModel(BoringModel):
def validation_step(self, batch, batch_idx, dataloader_idx):
output = self.layer(batch[0])
loss = self.loss(batch, output)
return {"x": loss}

def validation_epoch_end(self, outputs) -> None:
pass

def val_dataloader(self):
dls = [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(val_dl_count)]
return dls

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
overfit_batches=overfit_batches,
log_every_n_steps=1,
enable_model_summary=False,
)

trainer.fit(model)
assert trainer.num_training_batches == overfit_batches
assert len(trainer.num_val_batches) == val_dl_count
assert all(nbatches == overfit_batches for nbatches in trainer.num_val_batches)


@pytest.mark.parametrize("overfit_batches", [1, 2, 0.1, 0.25, 1.0])
def test_overfit_basic(tmpdir, overfit_batches):
"""Tests that only training_step can be used when overfitting."""
Expand Down
39 changes: 26 additions & 13 deletions tests/trainer/test_trainer_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,33 +112,46 @@ def test_overfit_batch_limits(tmpdir):
# ------------------------------------------------------
for split in (RunningStage.VALIDATING, RunningStage.TESTING):

# ------------------------------------------------------
# test overfit_batches action
# ------------------------------------------------------

# ------------------------------------------------------
# test overfit_batches as percent
# ------------------------------------------------------
trainer = Trainer(overfit_batches=0.11)
trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == num_train_samples

# make sure we turned off shuffle for the user
assert isinstance(dataloaders[0].sampler, SequentialSampler)

# make sure the loaders are the same
(xb, yb) = next(iter(dataloaders[0]))
assert torch.eq(xa, xb).all()
assert torch.eq(ya, yb).all()
loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
if split == RunningStage.VALIDATING:
assert loader_num_batches[0] == 0
else:
assert loader_num_batches[0] == len(test_loader)

# ------------------------------------------------------
# test overfit_batches as int
# ------------------------------------------------------
trainer = Trainer(overfit_batches=1)
trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == 1
if split == RunningStage.VALIDATING:
assert loader_num_batches[0] == 0
else:
assert loader_num_batches[0] == len(test_loader)
# make sure we turned off shuffle for the user
assert isinstance(dataloaders[0].sampler, SequentialSampler)

# make sure the loaders are the same
(xb, yb) = next(iter(dataloaders[0]))
assert torch.eq(xa, xb).all()
assert torch.eq(ya, yb).all()

trainer = Trainer(overfit_batches=5)
trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == 5
loader_num_batches, _ = trainer._reset_eval_dataloader(split, model=model)
if split == RunningStage.VALIDATING:
assert loader_num_batches[0] == 0
else:
assert loader_num_batches[0] == len(test_loader)

# ------------------------------------------------------
# test limit_xxx_batches as percent AND int
Expand Down