Add BatchSizeFinder callback #11089

Merged
merged 54 commits on Sep 27, 2022

Commits (54)
f704d45
add BatchSizeFinderCallback callback
rohitgr7 Dec 14, 2021
9ed344d
temp rm from init
rohitgr7 Dec 15, 2021
b8e4b56
skip with lr_finder tests
rohitgr7 Dec 15, 2021
b4f6dea
restore loops and integrate early exit
rohitgr7 Dec 28, 2021
5da0dd2
enable fast_dev_run test
rohitgr7 Dec 28, 2021
b30a207
add docs and tests
rohitgr7 Jan 4, 2022
dd909bf
keep tune and remove early_exit
rohitgr7 Jan 5, 2022
edfe220
add more tests
rohitgr7 Jan 5, 2022
6b99b72
patch lr finder
rohitgr7 Jan 5, 2022
9f79123
disable skip
rohitgr7 Jan 5, 2022
1361559
force_save and fix test
rohitgr7 Jan 5, 2022
7fe3106
mypy and circular import fix
rohitgr7 Jan 5, 2022
9097648
fix mypy
rohitgr7 Jan 6, 2022
8d9008a
fix
rohitgr7 Jan 6, 2022
4405236
updates
rohitgr7 Jan 16, 2022
a71b953
rebase
rohitgr7 Feb 21, 2022
e7431c9
address reviews
rohitgr7 Feb 21, 2022
cd5c36f
add more exceptions for unsupported functionalities
rohitgr7 Feb 21, 2022
0312c97
move exception to setup
rohitgr7 Feb 21, 2022
677c156
chlog
rohitgr7 Feb 21, 2022
5783fbf
unit test
rohitgr7 Feb 21, 2022
566ee98
address reviews
rohitgr7 Feb 22, 2022
847f971
Apply suggestions from code review
rohitgr7 Feb 23, 2022
ce4ee66
update
rohitgr7 Jun 29, 2022
e099004
update
rohitgr7 Jun 29, 2022
ff0b227
mypy
rohitgr7 Jun 30, 2022
58575d6
Merge branch 'master' into ref/bs_finder_tuner
rohitgr7 Jun 30, 2022
7d91fce
Merge branch 'master' into ref/bs_finder_tuner
rohitgr7 Jul 12, 2022
7922e2d
fix
rohitgr7 Jul 12, 2022
e294c42
use it as a util func
rohitgr7 Jul 13, 2022
0127310
Merge remote-tracking branch 'origin/master' into ref/bs_finder_tuner
rohitgr7 Jul 13, 2022
1531483
license
rohitgr7 Jul 13, 2022
3dc10dc
Merge branch 'master' into ref/bs_finder_tuner
rohitgr7 Jul 19, 2022
deb1424
Merge branch 'master' into ref/bs_finder_tuner
rohitgr7 Aug 9, 2022
61e74f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2022
f7ca854
mypy
rohitgr7 Aug 10, 2022
269a5b6
Merge branch 'master' into ref/bs_finder_tuner
rohitgr7 Aug 10, 2022
c1dd622
mypy
rohitgr7 Aug 10, 2022
795bd28
Merge branch 'master' into ref/bs_finder_tuner
rohitgr7 Aug 25, 2022
cb1c872
review
rohitgr7 Aug 25, 2022
361fe6c
fix
rohitgr7 Aug 26, 2022
84f2200
Merge branch 'master' into ref/bs_finder_tuner
rohitgr7 Aug 26, 2022
2952700
Merge branch 'master' into ref/bs_finder_tuner
rohitgr7 Sep 10, 2022
c1d8315
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2022
0b1fd82
fix
rohitgr7 Sep 10, 2022
7ef720f
Merge branch 'master' into ref/bs_finder_tuner
otaj Sep 15, 2022
efc7bf1
updates
rohitgr7 Sep 16, 2022
c23169a
updates
rohitgr7 Sep 16, 2022
8cfb386
Merge branch 'master' into ref/bs_finder_tuner
rohitgr7 Sep 16, 2022
df4e2c1
fix import
rohitgr7 Sep 16, 2022
aa09a60
Protect callback attrs
carmocca Sep 19, 2022
c0118da
don't reset val dataloader
rohitgr7 Sep 24, 2022
560bbd9
Merge remote-tracking branch 'origin/master' into ref/bs_finder_tuner
rohitgr7 Sep 24, 2022
5418f7f
update test
rohitgr7 Sep 24, 2022
1 change: 1 addition & 0 deletions docs/source-pytorch/api_references.rst
@@ -30,6 +30,7 @@ callbacks
BackboneFinetuning
BaseFinetuning
BasePredictionWriter
BatchSizeFinder
Callback
DeviceStatsMonitor
EarlyStopping
1 change: 1 addition & 0 deletions docs/source-pytorch/extensions/callbacks.rst
@@ -87,6 +87,7 @@ Lightning has a few built-in callbacks.
BackboneFinetuning
BaseFinetuning
BasePredictionWriter
BatchSizeFinder
Callback
DeviceStatsMonitor
EarlyStopping
5 changes: 5 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `BatchSizeFinder` callback ([#11089](https://github.com/PyTorchLightning/pytorch-lightning/pull/11089))


- Tuner now supports a new `method` argument which will determine when to run the `BatchSizeFinder`: one of `fit`, `validate`, `test` or `predict` ([#11089](https://github.com/PyTorchLightning/pytorch-lightning/pull/11089))


- Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290))

4 changes: 3 additions & 1 deletion src/pytorch_lightning/callbacks/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.callbacks.checkpoint import Checkpoint
from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
@@ -32,6 +33,8 @@
__all__ = [
"BackboneFinetuning",
"BaseFinetuning",
"BasePredictionWriter",
"BatchSizeFinder",
"Callback",
"Checkpoint",
"DeviceStatsMonitor",
@@ -42,7 +45,6 @@
"ModelCheckpoint",
"ModelPruning",
"ModelSummary",
"BasePredictionWriter",
"ProgressBarBase",
"QuantizationAwareTraining",
"RichModelSummary",
149 changes: 149 additions & 0 deletions src/pytorch_lightning/callbacks/batch_size_finder.py
@@ -0,0 +1,149 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
BatchSizeFinder
===============

Finds optimal batch size
"""

from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
from pytorch_lightning.utilities.exceptions import _TunerExitException, MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn


class BatchSizeFinder(Callback):
SUPPORTED_MODES = ("power", "binsearch")

def __init__(
self,
mode: str = "power",
steps_per_trial: int = 3,
init_val: int = 2,
max_trials: int = 25,
batch_arg_name: str = "batch_size",
) -> None:
"""The ``BatchSizeFinder`` callback tries to find the largest batch size for a given model that does not
give an out of memory (OOM) error. All you need to do is add it as a callback inside Trainer and call
``trainer.{fit,validate,test,predict}``. Internally it calls the respective step function
``steps_per_trial`` times for each batch size until one of the batch sizes generates an OOM error.

Args:
mode: search strategy to update the batch size:

- ``'power'``: Keep multiplying the batch size by 2, until we get an OOM error.
- ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error
do a binary search between the last successful batch size and the batch size that failed.

steps_per_trial: number of steps to run with a given batch size.
Ideally 1 should be enough to test if an OOM error occurs,
however in practice a few are needed.

init_val: initial batch size to start the search with.

max_trials: max number of increases in batch size done before
the algorithm is terminated.

batch_arg_name: name of the attribute that stores the batch size.
It is expected that the user has provided a model or datamodule that has a hyperparameter
with that name. We will look for this attribute name in the following places

- ``model``
- ``model.hparams``
- ``trainer.datamodule`` (the datamodule passed to the tune method)
"""
mode = mode.lower()
if mode not in self.SUPPORTED_MODES:
raise ValueError(f"`mode` should be either of {self.SUPPORTED_MODES}")
self.optimal_batch_size = init_val
self._mode = mode
self._steps_per_trial = steps_per_trial
self._init_val = init_val
self._max_trials = max_trials
self._batch_arg_name = batch_arg_name
self._early_exit = False

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
if trainer._accelerator_connector.is_distributed:
raise MisconfigurationException("The Batch size finder is not supported with distributed strategies.")

running_stage = trainer.state.stage
assert running_stage is not None
dl_source = getattr(trainer._data_connector, f"_{running_stage.dataloader_prefix}_dataloader_source")

# TODO: check if this can be enabled (#4040)
if not trainer._data_connector._train_dataloader_source.is_module():
raise MisconfigurationException(
"The Batch size finder cannot be used with dataloaders passed directly to `.fit()`. Please disable"
" the feature or incorporate the dataloader into your LightningModule or LightningDataModule."
)

# TODO: Add support for multiple eval dataloader
if stage != "fit":
dataloaders = dl_source.dataloader()
if isinstance(dataloaders, list) and len(dataloaders) > 1:
raise MisconfigurationException(
f"The Batch size finder cannot be used with multiple {running_stage.dataloader_prefix} dataloaders."
)

if not lightning_hasattr(pl_module, self._batch_arg_name):
raise MisconfigurationException(
f"Field {self._batch_arg_name} not found in both `model` and `model.hparams`"
)

if (
hasattr(pl_module, self._batch_arg_name)
and hasattr(pl_module, "hparams")
and self._batch_arg_name in pl_module.hparams
):
rank_zero_warn(
f"Field `model.{self._batch_arg_name}` and `model.hparams.{self._batch_arg_name}` are mutually"
f" exclusive! `model.{self._batch_arg_name}` will be used as the initial batch size for scaling."
" If this is not the intended behavior, please remove either one."
)

def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
new_size = scale_batch_size(
trainer,
pl_module,
self._mode,
self._steps_per_trial,
self._init_val,
self._max_trials,
self._batch_arg_name,
)

self.optimal_batch_size = new_size
if self._early_exit:
raise _TunerExitException()

def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.scale_batch_size(trainer, pl_module)

def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if trainer.sanity_checking or trainer.state.fn != "validate":
return

self.scale_batch_size(trainer, pl_module)

def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.scale_batch_size(trainer, pl_module)

def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.scale_batch_size(trainer, pl_module)
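
A minimal, self-contained usage sketch of the callback described in the docstring above (illustrative only, not part of the diff; `DemoModule` is a hypothetical LightningModule whose train dataloader reads `self.batch_size`, the attribute the finder scales):

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BatchSizeFinder


class DemoModule(pl.LightningModule):
    """Hypothetical module: its dataloader reads `self.batch_size`, which the finder tunes in place."""

    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(512, 32), torch.randint(0, 2, (512,)))
        return DataLoader(dataset, batch_size=self.batch_size)


finder = BatchSizeFinder(mode="binsearch", init_val=4, max_trials=10)
trainer = pl.Trainer(max_epochs=1, callbacks=[finder])
trainer.fit(DemoModule())  # the finder runs in `on_fit_start`, then normal fitting continues
print(finder.optimal_batch_size)  # largest batch size that ran without an OOM error

Because `_early_exit` defaults to `False`, the run continues with the scaled batch size rather than stopping after the search.
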
26 changes: 19 additions & 7 deletions src/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -28,6 +28,7 @@
RichProgressBar,
TQDMProgressBar,
)
from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -229,6 +230,30 @@ def _attach_model_callbacks(self) -> None:

@staticmethod
def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
"""Moves all Checkpoint callbacks to the end of the list. The sequential order within the group of
checkpoint callbacks is preserved, as well as the order of all other callbacks.
"""Moves all the tuner specific callbacks at the beginning of the list and all the `ModelCheckpoint`
callbacks to the end of the list. The sequential order within the group of checkpoint callbacks is
preserved, as well as the order of all other callbacks.

Args:
callbacks: A list of callbacks.

Return:
A new list in which the last elements are Checkpoint if there were any present in the
input.
A new list in which the first elements are tuner specific callbacks and last elements are ModelCheckpoints
if there were any present in the input.
"""
checkpoints: List[Callback] = [c for c in callbacks if isinstance(c, Checkpoint)]
not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
return not_checkpoints + checkpoints
tuner_callbacks: List[Callback] = []
other_callbacks: List[Callback] = []
checkpoint_callbacks: List[Callback] = []

for cb in callbacks:
if isinstance(cb, BatchSizeFinder):
tuner_callbacks.append(cb)
elif isinstance(cb, Checkpoint):
checkpoint_callbacks.append(cb)
else:
other_callbacks.append(cb)

return tuner_callbacks + other_callbacks + checkpoint_callbacks


def _configure_external_callbacks() -> List[Callback]:
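
As an aside on the `_reorder_callbacks` change above (an illustration added here, not part of the diff): a user-supplied callback list is rearranged so that tuner callbacks come first and checkpoint callbacks last, e.g.

from pytorch_lightning.callbacks import BatchSizeFinder, EarlyStopping, ModelCheckpoint

user_callbacks = [ModelCheckpoint(), EarlyStopping(monitor="val_loss"), BatchSizeFinder()]
# expected order after _reorder_callbacks:
#   [BatchSizeFinder, EarlyStopping, ModelCheckpoint]
# so for any given hook, the batch-size search runs before the other callbacks,
# and checkpointing still runs last, as before this PR
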
8 changes: 8 additions & 0 deletions src/pytorch_lightning/trainer/teardown.py
@@ -16,6 +16,7 @@

from lightning_lite.utilities.distributed import distributed_available
from pytorch_lightning.trainer.states import TrainerStatus
from pytorch_lightning.utilities.exceptions import _TunerExitException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn


@@ -34,6 +35,13 @@ def call_and_handle_interrupt(trainer: Any, trainer_fn: Callable, *args: Any, **
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
else:
return trainer_fn(*args, **kwargs)

except _TunerExitException:
trainer._call_teardown_hook()
trainer._teardown()
trainer.state.status = TrainerStatus.FINISHED
trainer.state.stage = None

# TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
except KeyboardInterrupt as exception:
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
39 changes: 16 additions & 23 deletions src/pytorch_lightning/trainer/trainer.py
@@ -41,6 +41,7 @@
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from typing_extensions import Literal

import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import get_filesystem
@@ -901,9 +902,11 @@ def tune(
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
lr_find_kwargs: Optional[Dict[str, Any]] = None,
method: Literal["fit", "validate", "test", "predict"] = "fit",
) -> _TunerResult:
r"""
Runs routines to tune hyperparameters before training.
@@ -917,44 +920,34 @@

val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.

dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
samples used for running tuner on validation/testing/prediction.

datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size`

lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find`

method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
"""
if not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.tune()` requires a `LightningModule`, got: {model.__class__.__qualname__}")

Trainer._log_api_event("tune")

self.state.fn = TrainerFn.TUNING
self.state.status = TrainerStatus.RUNNING
self.tuning = True

# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloaders, LightningDataModule):
datamodule = train_dataloaders
train_dataloaders = None
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
raise MisconfigurationException(
"You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`"
)

# links data to the trainer
self._data_connector.attach_data(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)

with isolate_rng():
result = self.tuner._tune(
model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs
model,
train_dataloaders,
val_dataloaders,
dataloaders,
datamodule,
scale_batch_size_kwargs=scale_batch_size_kwargs,
lr_find_kwargs=lr_find_kwargs,
method=method,
)

assert self.state.stopped
self.tuning = False

return result

def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
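
Finally, a hedged sketch of the updated `Trainer.tune()` with the new `method` argument (reusing the hypothetical `DemoModule` from the earlier sketch; enabling batch-size scaling through the pre-existing `auto_scale_batch_size` flag is an assumption about this release and worth checking against the docs for your version):

trainer = pl.Trainer(auto_scale_batch_size="binsearch")
result = trainer.tune(
    DemoModule(),
    scale_batch_size_kwargs={"max_trials": 10},
    method="fit",  # new argument: one of "fit", "validate", "test" or "predict"
)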