
Commit d71501d

Reset val_dataloader in tuner/batch_size_scaling (#9857)
* reset val
* chlog
1 parent 8740c80 commit d71501d
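For context, the code path this commit touches is the batch size finder. Below is a minimal sketch of how it is typically driven; the toy dataset and module are illustrative stand-ins (not part of this commit), while the `Trainer`/`tune` arguments mirror the test added here.

```python
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """Toy dataset: 64 samples with 32 features each (illustrative only)."""

    def __init__(self, size=32, length=64):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class BatchSizeModel(LightningModule):
    """Module whose dataloaders read `self.batch_size`, so the tuner can scale it."""

    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        self.layer(batch)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=self.batch_size)


# `auto_scale_batch_size=True` runs the power scaling in tuner/batch_size_scaling.py;
# with this commit, each accepted batch size is propagated to the val dataloader too.
trainer = Trainer(max_epochs=2, auto_scale_batch_size=True)
result = trainer.tune(BatchSizeModel(batch_size=4), scale_batch_size_kwargs={"max_trials": 5, "init_val": 4})
print(result["scale_batch_size"])
```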

File tree

4 files changed: +20 -1 lines changed

- CHANGELOG.md
- pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
- pytorch_lightning/tuner/batch_size_scaling.py
- tests/tuner/test_scale_batch_size.py

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -470,6 +470,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed missing arguments when saving hyperparameters from the parent class but not from the child class ([#9800](https://github.com/PyTorchLightning/pytorch-lightning/pull/9800))
 
 
+- Reset `val_dataloader` in `tuner/batch_size_scaling` ([#9857](https://github.com/PyTorchLightning/pytorch-lightning/pull/9857))
+
+
 ## [1.4.9] - 2021-09-30
 
 - Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 1 addition & 1 deletion

@@ -181,7 +181,7 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
 
         # log results of evaluation
         if (
-            self.trainer.state.fn != TrainerFn.FITTING
+            self.trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING)
             and self.trainer.evaluating
             and self.trainer.is_global_zero
             and self.trainer.verbose_evaluate
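This guard decides when the per-epoch table of evaluation metrics is printed. Adding `TrainerFn.TUNING` to the exclusion list keeps that printout silent during the tuner's internal runs, which now also touch the validation dataloader; previously only `FITTING` was excluded.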

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 1 addition & 0 deletions

@@ -166,6 +166,7 @@ def _run_power_scaling(
         if changed:
             # Force the train dataloader to reset as the batch size has changed
             trainer.reset_train_dataloader(model)
+            trainer.reset_val_dataloader(model)
         else:
             break
     return new_size
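For orientation, here is a simplified sketch of the `_run_power_scaling` loop that the hunk above sits in: the batch size is doubled each trial until an out-of-memory error or `max_trials` is reached. The underscore-prefixed helpers (`_try_fit`, `_adjust_batch_size`, `_is_oom_error`) are illustrative placeholders, not the module's exact internal signatures.

```python
def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
    """Double the batch size each trial until an OOM error or `max_trials` is reached.

    Simplified illustration of the control flow around the patched lines; the real
    implementation also handles CUDA garbage collection and restores trainer state.
    """
    for _ in range(max_trials):
        try:
            _try_fit(trainer, model, **fit_kwargs)  # placeholder: short trial run at the current batch size
            new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0)  # placeholder: double it
        except RuntimeError as exception:
            if _is_oom_error(exception):  # placeholder: detect an out-of-memory failure
                new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5)  # back off to last good size
                break
            raise

        if changed:
            # Both loaders must be rebuilt so they pick up the new batch size;
            # this commit adds the val loader reset (previously only the train loader was reset).
            trainer.reset_train_dataloader(model)
            trainer.reset_val_dataloader(model)
        else:
            break
    return new_size
```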

tests/tuner/test_scale_batch_size.py

Lines changed: 15 additions & 0 deletions

@@ -48,6 +48,9 @@ def __init__(self, batch_size):
     def train_dataloader(self):
         return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
 
+    def val_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
+
 
 @pytest.mark.parametrize(["model_bs", "dm_bs"], [(2, -1), (2, 2), (2, None), (None, 2), (16, 16)])
 def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
@@ -266,3 +269,15 @@ def __init__(self):
     trainer.tune(model)
     with pytest.raises(ValueError, match="could either be `power` or `binsearch`"):
         trainer.tuner.scale_batch_size(model, mode="ThisModeDoesNotExist")
+
+
+def test_dataloader_reset_with_scale_batch_size(tmpdir):
+    """Test that train and val dataloaders are reset at every update in scale batch size."""
+    model = BatchSizeModel(batch_size=16)
+    scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4}
+
+    trainer = Trainer(max_epochs=2, auto_scale_batch_size=True)
+    new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+
+    assert trainer.train_dataloader.loaders.batch_size == new_batch_size
+    assert trainer.val_dataloaders[0].batch_size == new_batch_size
