
Commit 88a00d0

add docs and tests
1 parent 7b19a02 commit 88a00d0

File tree: 3 files changed (+104, -16 lines)


pytorch_lightning/callbacks/batch_size_finder.py

Lines changed: 46 additions & 7 deletions
@@ -48,10 +48,42 @@ def __init__(
         batch_arg_name="batch_size",
         early_exit=False,
     ):
+        """Callback that tries to find the largest batch size for a given model that does not give an out of memory
+        (OOM) error. It works with both training and evaluation. All you need to do is add it as a callback inside
+        Trainer and call ``trainer.fit/validate/test/predict()``. Internally it calls the respective step function
+        ``steps_per_trial`` times for each batch size until one of the batch sizes generates an OOM error.
 
+        Args:
+            mode: search strategy to update the batch size:
+
+                - ``'power'`` (default): Keep multiplying the batch size by 2, until we get an OOM error.
+                - ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error
+                  do a binary search between the last successful batch size and the batch size that failed.
+
+            steps_per_trial: number of steps to run with a given batch size.
+                Ideally 1 should be enough to test if an OOM error occurs,
+                however in practice a few are needed.
+
+            init_val: initial batch size to start the search with.
+
+            max_trials: max number of increases in batch size done before
+                the algorithm is terminated.
+
+            batch_arg_name: name of the attribute that stores the batch size.
+                It is expected that the user has provided a model or datamodule that has a hyperparameter
+                with that name. We will look for this attribute name in the following places:
+
+                - ``model``
+                - ``model.hparams``
+                - ``trainer.datamodule`` (the datamodule passed to the tune method)
+
+            early_exit: whether to continue with the training/evaluation or stop after
+                an optimal batch size has been found.
+        """
+        supported_modes = ("power", "binsearch")
         mode = mode.lower()
-        if mode not in ("power", "binsearch"):
-            raise MisconfigurationException("`mode` should be either 'power' or 'binsearch'")
+        if mode not in supported_modes:
+            raise MisconfigurationException(f"`mode` should be one of {supported_modes}")
 
         self.mode = mode
         self.steps_per_trial = steps_per_trial
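
For context, a minimal usage sketch of the callback this docstring describes. The DemoModel module here is hypothetical (any LightningModule whose dataloaders read a batch_size attribute works); the import path and constructor arguments are the ones added in this commit:

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder


class DemoModel(LightningModule):
    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size  # attribute the finder will tune (matches batch_arg_name)
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=self.batch_size)


# double the batch size until OOM, trying at most 4 candidates, 3 steps each
finder = BatchSizeFinder(mode="power", steps_per_trial=3, max_trials=4, batch_arg_name="batch_size")
trainer = Trainer(callbacks=[finder], max_epochs=1)
model = DemoModel()
trainer.fit(model)  # validate/test/predict work the same way
print(model.batch_size)  # largest batch size that ran without OOM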
@@ -121,6 +153,10 @@ def scale_batch_size(self, trainer, pl_module):
         if fs.exists(save_path):
             fs.rm(save_path)
 
+        # global step and current epoch are incremented before being saved in the checkpoint
+        trainer.fit_loop.global_step -= 1
+        trainer.fit_loop.current_epoch -= 1
+
         self._restore_params(trainer)
 
         if trainer.progress_bar_callback:
@@ -165,7 +201,7 @@ def _run_binary_scaling(self, trainer, pl_module, new_size):
         while True:
             garbage_collection_cuda()
             try:
-                # Try fit
+                # run loop
                 self._try_loop_run(trainer)
                 count += 1
                 if count > self.max_trials:
@@ -217,7 +253,7 @@ def _try_loop_run(self, trainer):
         elif trainer.state.fn == TrainerFn.PREDICTING:
             loop = trainer.predict_loop
 
-        loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]))
+        loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]), force_load_progress=True)
         loop.run()
 
     @staticmethod
@@ -292,16 +328,16 @@ def _restore_params(self, trainer):
             loop = trainer.predict_loop
             trainer.limit_predict_batches = self._dumped_params["limit_predict_batches"]
 
-        loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]))
+        loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]), force_load_progress=True)
         if "loop_verbose" in self._dumped_params:
             loop.verbose = self._dumped_params["loop_verbose"]
 
     def pre_early_exit(self, trainer):
         if trainer.fast_dev_run:
             return
 
+        # this is required to stop the respective loops
         if trainer.state.fn == TrainerFn.FITTING:
-            trainer.should_stop = True
             self._dumped_params["num_training_batches"] = trainer.num_training_batches
             trainer.num_training_batches = 0
         elif trainer.state.fn == TrainerFn.VALIDATING:
@@ -318,6 +354,7 @@ def post_early_exit(self, trainer):
         if trainer.fast_dev_run:
             return
 
+        # restore the state used to stop the respective loop
         if trainer.state.fn == TrainerFn.FITTING:
             trainer.num_training_batches = self._dumped_params["num_training_batches"]
             loop = trainer.fit_loop
@@ -331,7 +368,7 @@ def post_early_exit(self, trainer):
             trainer.num_predict_batches = self._dumped_params["num_predict_batches"]
             loop = trainer.predict_loop
 
-        loop.load_state_dict(self._dumped_params["loop_state_dict"])
+        loop.load_state_dict(self._dumped_params["loop_state_dict"], force_load_progress=True)
         trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
 
     def on_fit_start(self, trainer, pl_module):
@@ -346,6 +383,8 @@ def on_validation_start(self, trainer, pl_module):
         if trainer.sanity_checking or trainer.state.fn != TrainerFn.VALIDATING:
             return
 
+        self.scale_batch_size(trainer, pl_module)
+
         if self.early_exit:
             self.pre_early_exit(trainer)
         else:
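
For intuition, here is a standalone sketch of the two search strategies the new docstring describes. This is not the callback's actual implementation; run_trial is a hypothetical stand-in that runs steps_per_trial steps and returns True if the batch size fit in memory:

def find_batch_size(run_trial, init_val=2, max_trials=25, mode="power"):
    """Return the largest size for which run_trial(size) is True (no OOM)."""
    low, high = None, None  # last size that fit, first size that failed
    size = init_val
    for _ in range(max_trials):
        if run_trial(size):
            low = size
            if high is None:
                size *= 2                 # 'power' phase: keep doubling
            else:
                size = (low + high) // 2  # 'binsearch' phase: move back up
        else:
            if mode == "power":
                break                     # 'power' stops at the first OOM
            high = size
            size = (low + high) // 2 if low else max(size // 2, 1)
        if high is not None and high - (low or 0) <= 1:
            break
    return low

With a memory limit of 100, find_batch_size(lambda s: s <= 100, mode="power") returns 64, while mode="binsearch" narrows down to 100 for the same trial budget, which is why the docstring calls binary search out as a separate mode.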

pytorch_lightning/loops/base.py

Lines changed: 14 additions & 6 deletions
@@ -276,10 +276,9 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
         destination[prefix + "state_dict"] = self.on_save_checkpoint()
 
         # do not get the mode from `self.trainer` because it might not have been attached yet
-        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
         for k, v in self.__dict__.items():
             key = prefix + k
-            if ft_enabled and isinstance(v, BaseProgress):
+            if isinstance(v, BaseProgress):
                 destination[key] = v.state_dict()
             elif isinstance(v, Loop):
                 v.state_dict(destination, key + ".")
@@ -296,21 +295,30 @@ def load_state_dict(
         state_dict: Dict,
         prefix: str = "",
         metrics: Optional[Dict[str, Metric]] = None,
+        force_load_progress: bool = False,
     ) -> None:
         """Loads the state of this loop and all its children."""
-        self._load_from_state_dict(state_dict.copy(), prefix, metrics)
+        self._load_from_state_dict(state_dict.copy(), prefix, metrics, force_load_progress)
         for k, v in self.__dict__.items():
             if isinstance(v, Loop):
-                v.load_state_dict(state_dict.copy(), prefix + k + ".")
+                v.load_state_dict(state_dict.copy(), prefix + k + ".", force_load_progress=force_load_progress)
+
+    def _load_from_state_dict(
+        self,
+        state_dict: Dict,
+        prefix: str,
+        metrics: Optional[Dict[str, Metric]] = None,
+        force_load_progress: bool = False,
+    ) -> None:
+        load_progress = _FaultTolerantMode.detect_current_mode().is_enabled or force_load_progress
 
-    def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
         for k, v in self.__dict__.items():
             key = prefix + k
             if key not in state_dict:
                 # no state for this object, maybe we are loading an old checkpoint
                 continue
 
-            if isinstance(v, BaseProgress):
+            if load_progress and isinstance(v, BaseProgress):
                 v.load_state_dict(state_dict[key])
             elif (
                 isinstance(v, _ResultCollection)
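
The gating pattern this hunk introduces, shown in isolation: progress counters are restored only when fault tolerance is enabled or the caller forces it (as BatchSizeFinder now does). A minimal sketch with a hypothetical Progress class standing in for BaseProgress:

from typing import Dict


class Progress:
    """Hypothetical stand-in for a BaseProgress counter."""

    def __init__(self) -> None:
        self.completed = 0

    def load_state_dict(self, state: Dict) -> None:
        self.completed = state["completed"]


def load_loop_state(loop, state_dict: Dict, ft_enabled: bool, force_load_progress: bool = False) -> None:
    # progress counters are skipped unless fault tolerance is on
    # or a caller explicitly forces the load
    load_progress = ft_enabled or force_load_progress
    for key, value in vars(loop).items():
        if load_progress and isinstance(value, Progress) and key in state_dict:
            value.load_state_dict(state_dict[key])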

tests/tuner/test_scale_batch_size.py

Lines changed: 44 additions & 3 deletions
@@ -20,6 +20,7 @@
 
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
 from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -49,6 +50,12 @@ def train_dataloader(self):
     def val_dataloader(self):
         return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
 
+    def test_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
+
+    def predict_dataloader(self):
+        return DataLoader(RandomDataset(32, 64), batch_size=getattr(self, "batch_size", 1))
+
 
 @pytest.mark.parametrize(["model_bs", "dm_bs"], [(2, -1), (2, 2), (2, None), (None, 2), (16, 16)])
 def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
@@ -133,7 +140,7 @@ def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg):
     after_batch_size = model.batch_size
     assert before_batch_size != after_batch_size, "Batch size was not altered after running auto scaling of batch size"
 
-    assert not os.path.exists(tmpdir / "scale_batch_size_temp_model.ckpt")
+    assert not any(f for f in os.listdir(tmpdir) if f.startswith(".scale_batch_size_temp_model"))
 
 
 @RunIf(min_gpus=1)
@@ -275,9 +282,9 @@ def __init__(self):
         auto_scale_batch_size="ThisModeDoesNotExist",
     )
 
-    with pytest.raises(MisconfigurationException, match="should be either 'power' or 'binsearch'"):
+    with pytest.raises(MisconfigurationException, match="should be one of"):
         trainer.tune(model)
-    with pytest.raises(MisconfigurationException, match="should be either 'power' or 'binsearch'"):
+    with pytest.raises(MisconfigurationException, match="should be one of"):
         trainer.tuner.scale_batch_size(model, mode="ThisModeDoesNotExist")
 
 
@@ -292,3 +299,37 @@ def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):
 
     assert trainer.train_dataloader.loaders.batch_size == new_batch_size
     assert trainer.val_dataloaders[0].batch_size == new_batch_size
+
+
+@pytest.mark.parametrize("trainer_fn", ["fit", "validate", "test", "predict"])
+@pytest.mark.parametrize("early_exit", [False])
+# @pytest.mark.parametrize('early_exit', [True, False])
+def test_batch_size_finder_callback(tmpdir, trainer_fn, early_exit):
+    """Test batch size finder callback with different trainer methods."""
+    tutils.reset_seed()
+    before_batch_size = 2
+    model = BatchSizeModel(batch_size=before_batch_size)
+    batch_size_finder = BatchSizeFinder(max_trials=4, batch_arg_name="batch_size", early_exit=early_exit)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, callbacks=[batch_size_finder])
+    fn = getattr(trainer, trainer_fn)
+    fn(model)
+    after_batch_size = model.batch_size
+    loop = getattr(trainer, f"{trainer_fn}_loop")
+
+    if early_exit:
+        assert trainer.global_step == 0
+        assert trainer.current_epoch == 0
+        if trainer_fn != "fit":
+            assert loop.dataloader_progress.current.completed == 0
+            assert loop.epoch_loop.batch_progress.current.completed == 0
+    else:
+        if trainer_fn == "fit":
+            assert trainer.global_step == 4
+            assert trainer.current_epoch == 1
+        else:
+            assert trainer.global_step == 0
+            assert loop.dataloader_progress.current.completed == 1
+            assert loop.epoch_loop.batch_progress.current.completed == 2
+
+    assert before_batch_size != after_batch_size, "Batch size was not altered after running auto scaling of batch size"
+    assert not any(f for f in os.listdir(tmpdir) if f.startswith(".scale_batch_size_temp_model"))
