Skip to content

Commit cfd82a1

Browse files
committed
change corresponding logic according to #9989
1 parent 346ede5 commit cfd82a1

File tree

2 files changed

+6
-25
lines changed

2 files changed

+6
-25
lines changed

pytorch_lightning/trainer/connectors/debugging_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ def determine_data_use_amount(self, overfit_batches: float) -> None:
8282
# by overfit_batches
8383
if self.trainer.limit_train_batches > 0.0:
8484
self.trainer.limit_train_batches = overfit_batches
85-
if self.trainer.limit_val_batches > 0.0:
86-
self.trainer.limit_val_batches = overfit_batches
85+
# Disable validation completely when overfit_batches > 0
86+
self.trainer.limit_val_batches = 0.0
8787
if self.trainer.limit_test_batches > 0.0:
8888
self.trainer.limit_test_batches = overfit_batches
8989

tests/trainer/test_trainer.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -908,10 +908,8 @@ def validation_epoch_end(self, *args, **kwargs):
908908
assert model.validation_epoch_end_invoked, "did not run `validation_epoch_end` with `fast_dev_run=True`"
909909

910910

911-
def test_irrevelance_between_non_default_overfit_batches_and_non_default_batch_limitation(tmpdir):
912-
"""Verify that when `overfit_batches` > 0, `limit_train/val/test_batches` won't be resetted to
913-
`overfit_batches` unless they are of default value."""
914-
"""Assure that non-default value of `limit_train/val/test_batches` won't be reset by `DebuggingConnector` when `overfit_batches` > 0"""
911+
def test_disable_validation_when_overfit_batches_larger_than_zero(tmpdir):
912+
"""Verify that when `overfit_batches` > 0, there will be no validation"""
915913

916914
class CurrentModel(EvalModelTemplate):
917915

@@ -942,28 +940,11 @@ def validation_epoch_end(self, *args, **kwargs):
942940
trainer = Trainer(**trainer_options)
943941
trainer.fit(model)
944942

945-
# check that limit_xxx_batches won't be reset when they are non-default value and overfit_batches > 0
946943
assert trainer.state.finished, f"Training failed with {trainer.state}"
947944
assert trainer.current_epoch == 1
948945
assert trainer.limit_train_batches == 1
949-
assert trainer.limit_val_batches == 0.0
950-
assert not model.validation_step_invoked, "`validation_step` should not run when `limit_val_batches=0`"
951-
assert not model.validation_epoch_end_invoked, "`validation_epoch_end` should not run when `limit_val_batches=0`"
952-
953-
# check that limit_xxx_batches will be reset when they are default value and overfit_batches > 0
954-
model = CurrentModel(**hparams)
955-
trainer_options.update(overfit_batches=2)
956-
trainer_options.update(limit_train_batches=1.0)
957-
trainer_options.update(limit_val_batches=1.0)
958-
trainer = Trainer(**trainer_options)
959-
trainer.fit(model)
960-
961-
assert trainer.state.finished, f"Training failed with {trainer.state}"
962-
assert trainer.current_epoch == 1
963-
assert trainer.limit_train_batches == 2
964-
assert trainer.limit_val_batches == 2
965-
assert model.validation_step_invoked, "did not run `validation_step` with `fast_dev_run=True`"
966-
assert model.validation_epoch_end_invoked, "did not run `validation_epoch_end` with `fast_dev_run=True`"
946+
assert not model.validation_step_invoked, "`validation_step` should not run when `overfit_batches>0`"
947+
assert not model.validation_epoch_end_invoked, "`validation_step` should not run when `overfit_batches>0`"
967948

968949

969950
@mock.patch("torch.Tensor.backward")

0 commit comments

Comments
 (0)