Add BatchSizeFinder callback #11089
Merged
Commits (54 total; changes shown from 38 commits):

- f704d45 add BatchSizeFinderCallback callback (rohitgr7)
- 9ed344d temp rm from init (rohitgr7)
- b8e4b56 skip with lr_finder tests (rohitgr7)
- b4f6dea restore loops and intergrate early exit (rohitgr7)
- 5da0dd2 enable fast_dev_run test (rohitgr7)
- b30a207 add docs and tests (rohitgr7)
- dd909bf keep tune and remove early_exit (rohitgr7)
- edfe220 add more tests (rohitgr7)
- 6b99b72 patch lr finder (rohitgr7)
- 9f79123 disable skip (rohitgr7)
- 1361559 force_save and fix test (rohitgr7)
- 7fe3106 mypy and circular import fix (rohitgr7)
- 9097648 fix mypy (rohitgr7)
- 8d9008a fix (rohitgr7)
- 4405236 updates (rohitgr7)
- a71b953 rebase (rohitgr7)
- e7431c9 address reviews (rohitgr7)
- cd5c36f add more exceptions for unsupported functionalities (rohitgr7)
- 0312c97 move exception to setup (rohitgr7)
- 677c156 chlog (rohitgr7)
- 5783fbf unit test (rohitgr7)
- 566ee98 address reviews (rohitgr7)
- 847f971 Apply suggestions from code review (rohitgr7)
- ce4ee66 update (rohitgr7)
- e099004 update (rohitgr7)
- ff0b227 mypy (rohitgr7)
- 58575d6 Merge branch 'master' into ref/bs_finder_tuner (rohitgr7)
- 7d91fce Merge branch 'master' into ref/bs_finder_tuner (rohitgr7)
- 7922e2d fix (rohitgr7)
- e294c42 use it as a util func (rohitgr7)
- 0127310 Merge remote-tracking branch 'origin/master' into ref/bs_finder_tuner (rohitgr7)
- 1531483 license (rohitgr7)
- 3dc10dc Merge branch 'master' into ref/bs_finder_tuner (rohitgr7)
- deb1424 Merge branch 'master' into ref/bs_finder_tuner (rohitgr7)
- 61e74f4 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- f7ca854 mypy (rohitgr7)
- 269a5b6 Merge branch 'master' into ref/bs_finder_tuner (rohitgr7)
- c1dd622 mypy (rohitgr7)
- 795bd28 Merge branch 'master' into ref/bs_finder_tuner (rohitgr7)
- cb1c872 review (rohitgr7)
- 361fe6c fix (rohitgr7)
- 84f2200 Merge branch 'master' into ref/bs_finder_tuner (rohitgr7)
- 2952700 Merge branch 'master' into ref/bs_finder_tuner (rohitgr7)
- c1d8315 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 0b1fd82 fix (rohitgr7)
- 7ef720f Merge branch 'master' into ref/bs_finder_tuner (otaj)
- efc7bf1 updates (rohitgr7)
- c23169a updates (rohitgr7)
- 8cfb386 Merge branch 'master' into ref/bs_finder_tuner (rohitgr7)
- df4e2c1 fix import (rohitgr7)
- aa09a60 Protect callback attrs (carmocca)
- c0118da don't reset val dataloader (rohitgr7)
- 560bbd9 Merge remote-tracking branch 'origin/master' into ref/bs_finder_tuner (rohitgr7)
- 5418f7f update test (rohitgr7)
New file (146 lines added):

```python
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
BatchSizeFinder
===============

Finds optimal batch size
"""

from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
from pytorch_lightning.utilities.exceptions import _TunerExitException, MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn


class BatchSizeFinder(Callback):
    SUPPORTED_MODES = ("power", "binsearch")

    def __init__(
        self,
        mode: str = "power",
        steps_per_trial: int = 3,
        init_val: int = 2,
        max_trials: int = 25,
        batch_arg_name: str = "batch_size",
    ) -> None:
        """The ``BatchSizeFinder`` callback tries to find the largest batch size for a given model that does not
        give an out-of-memory (OOM) error. It works with both training and evaluation. All you need to do is add
        it as a callback inside ``Trainer`` and call ``trainer.fit/validate/test/predict()``. Internally, it calls
        the respective step function ``steps_per_trial`` times for each batch size until one of the batch sizes
        generates an OOM error.

        Args:
            mode: search strategy to update the batch size:

                - ``'power'``: Keep multiplying the batch size by 2 until we get an OOM error.
                - ``'binsearch'``: Initially keep multiplying by 2, and after encountering an OOM error
                  do a binary search between the last successful batch size and the batch size that failed.

            steps_per_trial: number of steps to run with a given batch size.
                Ideally 1 should be enough to test if an OOM error occurs,
                however in practice a few are needed.

            init_val: initial batch size to start the search with.

            max_trials: maximum number of batch size increases attempted before the
                algorithm is terminated.

            batch_arg_name: name of the attribute that stores the batch size.
                It is expected that the user has provided a model or datamodule that has a hyperparameter
                with that name. We will look for this attribute name in the following places:

                - ``model``
                - ``model.hparams``
                - ``trainer.datamodule`` (the datamodule passed to the tune method)
        """
        # TODO: Add input validation.
        mode = mode.lower()
        if mode not in self.SUPPORTED_MODES:
            raise MisconfigurationException(f"`mode` should be either of {self.SUPPORTED_MODES}")

        self.mode = mode
        self.steps_per_trial = steps_per_trial
        self.init_val = init_val
        self.max_trials = max_trials
        self.batch_arg_name = batch_arg_name
        self.optimal_batch_size = init_val
        self._early_exit = False

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        if trainer._accelerator_connector.is_distributed:
            raise MisconfigurationException("Batch size finder is not supported with distributed strategies.")

        running_stage = trainer.state.stage
        assert running_stage is not None
        dl_source = getattr(trainer._data_connector, f"_{running_stage.dataloader_prefix}_dataloader_source")

        # TODO: check if this can be enabled (#4040)
        if not trainer._data_connector._train_dataloader_source.is_module():
            raise MisconfigurationException(
                "Batch size finder cannot be used with dataloaders passed directly to `.fit()`. Please disable"
                " the feature or incorporate the dataloader into your LightningModule or LightningDataModule."
            )

        # TODO: Add support for multiple eval dataloaders
        if stage != "fit":
            dataloaders = dl_source.dataloader()
            if isinstance(dataloaders, list) and len(dataloaders) > 1:
                raise MisconfigurationException(
                    f"Batch size finder cannot be used with multiple {running_stage.dataloader_prefix} dataloaders."
                )

        if not lightning_hasattr(pl_module, self.batch_arg_name):
            raise MisconfigurationException(
                f"Field {self.batch_arg_name} not found in both `model` and `model.hparams`"
            )

        if (
            hasattr(pl_module, self.batch_arg_name)
            and hasattr(pl_module, "hparams")
            and self.batch_arg_name in pl_module.hparams
        ):
            rank_zero_warn(
                f"Field `model.{self.batch_arg_name}` and `model.hparams.{self.batch_arg_name}` are mutually"
                f" exclusive! `model.{self.batch_arg_name}` will be used as the initial batch size for scaling."
                " If this is not the intended behavior, please remove either one."
            )

    def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        new_size = scale_batch_size(
            trainer, pl_module, self.mode, self.steps_per_trial, self.init_val, self.max_trials, self.batch_arg_name
        )

        self.optimal_batch_size = new_size
        if self._early_exit:
            raise _TunerExitException()

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.scale_batch_size(trainer, pl_module)

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.sanity_checking or trainer.state.fn != "validate":
            return

        self.scale_batch_size(trainer, pl_module)

    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.scale_batch_size(trainer, pl_module)

    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.scale_batch_size(trainer, pl_module)
```
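The docstring describes two search strategies, `'power'` and `'binsearch'`. The following standalone sketch (not part of the PR, and much simpler than Lightning's actual `scale_batch_size` tuner) simulates both: `fits` is a hypothetical stand-in for "running `steps_per_trial` steps at this batch size without an OOM error".

```python
# Toy simulation of the 'power' and 'binsearch' strategies from the docstring.
# `fits(bs)` stands in for a trial run that either succeeds or hits OOM.

def find_batch_size(fits, mode="power", init_val=2, max_trials=25):
    """Return the largest batch size for which `fits(bs)` is True."""
    bs = init_val
    last_good = None
    for _ in range(max_trials):
        if not fits(bs):
            break  # first failing batch size found
        last_good = bs
        bs *= 2  # 'power' phase: keep doubling
    else:
        return last_good  # ran out of trials without a failure
    if mode == "power" or last_good is None:
        return last_good
    # 'binsearch' phase: binary search between last success and first failure
    low, high = last_good, bs
    while low + 1 < high:
        mid = (low + high) // 2
        if fits(mid):
            low = mid
        else:
            high = mid
    return low
```

With a toy memory limit of 100 samples, `'power'` stops at 64 (the last power of two that fits), while `'binsearch'` refines the answer to 100. In actual use, the callback is simply passed to `Trainer(callbacks=[...])` and the search runs when `trainer.fit/validate/test/predict()` is called, as the docstring states.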
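The `batch_arg_name` lookup order the docstring lists (`model`, then `model.hparams`, then `trainer.datamodule`) can be illustrated with a small sketch. This is a hypothetical helper for illustration only, not Lightning's `lightning_hasattr`; the class names are made up.

```python
# Toy illustration of the attribute lookup order for `batch_arg_name`:
# 1) the model itself, 2) model.hparams, 3) the attached datamodule.

def resolve_batch_arg(model, datamodule=None, name="batch_size"):
    if hasattr(model, name):
        return getattr(model, name)
    hparams = getattr(model, "hparams", None)
    if hparams is not None and name in hparams:
        return hparams[name]
    if datamodule is not None and hasattr(datamodule, name):
        return getattr(datamodule, name)
    raise AttributeError(f"no `{name}` found on model, model.hparams, or datamodule")

class ToyModel:  # hypothetical stand-in for a LightningModule
    pass

class ToyDataModule:  # hypothetical stand-in for a LightningDataModule
    pass
```

This also shows why the `setup` hook warns when both `model.batch_size` and `model.hparams.batch_size` exist: the direct attribute shadows the hparams entry, so only one is consulted.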
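The `_early_exit` flag in the diff works by raising `_TunerExitException` after the search, so the surrounding run can stop before any real training happens. A minimal sketch of that control flow, with all names invented for illustration (Lightning's actual loop handling is more involved):

```python
# Toy sketch of the early-exit mechanism: a callback raises a sentinel
# exception that the run loop catches to stop cleanly after tuning.

class TunerExit(Exception):  # stand-in for Lightning's `_TunerExitException`
    pass

class FinderCallback:
    def __init__(self, early_exit=False):
        self.early_exit = early_exit
        self.optimal_batch_size = None

    def on_fit_start(self):
        self.optimal_batch_size = 64  # pretend the search found this value
        if self.early_exit:
            raise TunerExit()

def run_fit(callback):
    """Invoke the callback hook; a TunerExit stops before 'training' begins."""
    try:
        callback.on_fit_start()
    except TunerExit:
        return "exited after tuning"
    return "trained"
```

Either way, the found value remains available on the callback as `optimal_batch_size`, matching the attribute set in `scale_batch_size` above.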