Skip to content

Commit 6bb4bd6

Browse files
committed
reset val dataloader for binsearch
1 parent 7b4df7b commit 6bb4bd6

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
546546

547547
- Fixed issue with non-init dataclass fields in `apply_to_collection` ([#9963](https://github.com/PyTorchLightning/pytorch-lightning/issues/9963))
548548

549+
- Reset `val_dataloader` in `tuner/batch_size_scaling` for binsearch ([#9975](https://github.com/PyTorchLightning/pytorch-lightning/pull/9975))
550+
549551

550552
## [1.4.9] - 2021-09-30
551553

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def _run_binsearch_scaling(
205205
if changed:
206206
# Force the train dataloader to reset as the batch size has changed
207207
trainer.reset_train_dataloader(model)
208+
trainer.reset_val_dataloader(model)
208209
else:
209210
break
210211

tests/tuner/test_scale_batch_size.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,11 @@ def __init__(self):
274274
trainer.tuner.scale_batch_size(model, mode="ThisModeDoesNotExist")
275275

276276

277-
def test_dataloader_reset_with_scale_batch_size(tmpdir):
277+
@pytest.mark.parametrize("scale_method", ["power", "binsearch"])
278+
def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):
278279
"""Test that train and val dataloaders are reset at every update in scale batch size."""
279280
model = BatchSizeModel(batch_size=16)
280-
scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4}
281+
scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4, "mode": scale_method}
281282

282283
trainer = Trainer(max_epochs=2, auto_scale_batch_size=True)
283284
new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]

0 commit comments

Comments
 (0)