
Commit 0f499db

Revert "Revert "Add BatchSizeFinder callback (#11089)""
This reverts commit 9cc4695.
1 parent 9cc4695 · commit 0f499db

File tree

15 files changed: +640 -169 lines changed


docs/source-pytorch/api_references.rst

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ callbacks
     BackboneFinetuning
     BaseFinetuning
     BasePredictionWriter
+    BatchSizeFinder
     Callback
     DeviceStatsMonitor
     EarlyStopping

docs/source-pytorch/extensions/callbacks.rst

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ Lightning has a few built-in callbacks.
     BackboneFinetuning
     BaseFinetuning
     BasePredictionWriter
+    BatchSizeFinder
     Callback
     DeviceStatsMonitor
     EarlyStopping

src/pytorch_lightning/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added `BatchSizeFinder` callback ([#11089](https://github.com/PyTorchLightning/pytorch-lightning/pull/11089))
+
+
+- Tuner now supports a new `method` argument which will determine when to run the `BatchSizeFinder`: one of `fit`, `validate`, `test` or `predict` ([#11089](https://github.com/PyTorchLightning/pytorch-lightning/pull/11089))
+
 
 - Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290))
 

src/pytorch_lightning/callbacks/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
 from pytorch_lightning.callbacks.callback import Callback
 from pytorch_lightning.callbacks.checkpoint import Checkpoint
 from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
@@ -32,6 +33,8 @@
 __all__ = [
     "BackboneFinetuning",
     "BaseFinetuning",
+    "BasePredictionWriter",
+    "BatchSizeFinder",
     "Callback",
     "Checkpoint",
     "DeviceStatsMonitor",
@@ -42,7 +45,6 @@
     "ModelCheckpoint",
     "ModelPruning",
     "ModelSummary",
-    "BasePredictionWriter",
     "ProgressBarBase",
     "QuantizationAwareTraining",
     "RichModelSummary",

src/pytorch_lightning/callbacks/batch_size_finder.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
BatchSizeFinder
===============

Finds optimal batch size
"""

from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
from pytorch_lightning.utilities.exceptions import _TunerExitException, MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn


class BatchSizeFinder(Callback):
    SUPPORTED_MODES = ("power", "binsearch")

    def __init__(
        self,
        mode: str = "power",
        steps_per_trial: int = 3,
        init_val: int = 2,
        max_trials: int = 25,
        batch_arg_name: str = "batch_size",
    ) -> None:
        """The ``BatchSizeFinder`` callback tries to find the largest batch size for a given model that does not
        give an out of memory (OOM) error. All you need to do is add it as a callback inside Trainer and call
        ``trainer.{fit,validate,test,predict}``. Internally it calls the respective step function
        ``steps_per_trial`` times for each batch size until one of the batch sizes generates an OOM error.

        Args:
            mode: search strategy to update the batch size:

                - ``'power'``: Keep multiplying the batch size by 2, until we get an OOM error.
                - ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error
                  do a binary search between the last successful batch size and the batch size that failed.

            steps_per_trial: number of steps to run with a given batch size.
                Ideally 1 should be enough to test if an OOM error occurs,
                however in practice a few are needed.

            init_val: initial batch size to start the search with.

            max_trials: max number of increases in batch size done before
                algorithm is terminated

            batch_arg_name: name of the attribute that stores the batch size.
                It is expected that the user has provided a model or datamodule that has a hyperparameter
                with that name. We will look for this attribute name in the following places

                - ``model``
                - ``model.hparams``
                - ``trainer.datamodule`` (the datamodule passed to the tune method)
        """
        mode = mode.lower()
        if mode not in self.SUPPORTED_MODES:
            raise ValueError(f"`mode` should be either of {self.SUPPORTED_MODES}")
        self.optimal_batch_size = init_val
        self._mode = mode
        self._steps_per_trial = steps_per_trial
        self._init_val = init_val
        self._max_trials = max_trials
        self._batch_arg_name = batch_arg_name
        self._early_exit = False

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        if trainer._accelerator_connector.is_distributed:
            raise MisconfigurationException("The Batch size finder is not supported with distributed strategies.")

        running_stage = trainer.state.stage
        assert running_stage is not None
        dl_source = getattr(trainer._data_connector, f"_{running_stage.dataloader_prefix}_dataloader_source")

        # TODO: check if this can be enabled (#4040)
        if not trainer._data_connector._train_dataloader_source.is_module():
            raise MisconfigurationException(
                "The Batch size finder cannot be used with dataloaders passed directly to `.fit()`. Please disable"
                " the feature or incorporate the dataloader into your LightningModule or LightningDataModule."
            )

        # TODO: Add support for multiple eval dataloader
        if stage != "fit":
            dataloaders = dl_source.dataloader()
            if isinstance(dataloaders, list) and len(dataloaders) > 1:
                raise MisconfigurationException(
                    f"The Batch size finder cannot be used with multiple {running_stage.dataloader_prefix} dataloaders."
                )

        if not lightning_hasattr(pl_module, self._batch_arg_name):
            raise MisconfigurationException(
                f"Field {self._batch_arg_name} not found in both `model` and `model.hparams`"
            )

        if (
            hasattr(pl_module, self._batch_arg_name)
            and hasattr(pl_module, "hparams")
            and self._batch_arg_name in pl_module.hparams
        ):
            rank_zero_warn(
                f"Field `model.{self._batch_arg_name}` and `model.hparams.{self._batch_arg_name}` are mutually"
                f" exclusive! `model.{self._batch_arg_name}` will be used as the initial batch size for scaling."
                " If this is not the intended behavior, please remove either one."
            )

    def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        new_size = scale_batch_size(
            trainer,
            pl_module,
            self._mode,
            self._steps_per_trial,
            self._init_val,
            self._max_trials,
            self._batch_arg_name,
        )

        self.optimal_batch_size = new_size
        if self._early_exit:
            raise _TunerExitException()

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.scale_batch_size(trainer, pl_module)

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.sanity_checking or trainer.state.fn != "validate":
            return

        self.scale_batch_size(trainer, pl_module)

    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.scale_batch_size(trainer, pl_module)

    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.scale_batch_size(trainer, pl_module)
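
Not part of this diff, but a minimal usage sketch of the new callback may help orient readers. The toy dataset, the `LitModel`, and the chosen hyperparameter values below are illustrative assumptions; only `BatchSizeFinder`, its constructor arguments, and `optimal_batch_size` come from the file above.

# Hypothetical usage sketch (not from this commit): the callback scans batch sizes
# by running a few steps per candidate and records the largest one that fits.
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BatchSizeFinder


class RandomDataset(Dataset):
    def __len__(self):
        return 256

    def __getitem__(self, idx):
        return torch.randn(32), torch.randint(0, 2, (1,)).float()


class LitModel(pl.LightningModule):
    def __init__(self, batch_size: int = 2):
        super().__init__()
        # `batch_size` matches the default `batch_arg_name`, so the finder can rescale it.
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.binary_cross_entropy_with_logits(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # The dataloader must come from the module (or a datamodule) so it can be rebuilt per trial.
        return DataLoader(RandomDataset(), batch_size=self.batch_size)


if __name__ == "__main__":
    finder = BatchSizeFinder(mode="binsearch", steps_per_trial=3, init_val=2, max_trials=10)
    trainer = pl.Trainer(max_epochs=1, callbacks=[finder])
    trainer.fit(LitModel())
    print("suggested batch size:", finder.optimal_batch_size)

Because the train dataloader is built from `self.batch_size` inside the module, each trial can rebuild it with the candidate batch size; that is also why `setup` above rejects dataloaders passed directly to `.fit()`.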

src/pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 19 additions & 7 deletions
@@ -28,6 +28,7 @@
     RichProgressBar,
     TQDMProgressBar,
 )
+from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
 from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
 from pytorch_lightning.callbacks.timer import Timer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -229,19 +230,30 @@ def _attach_model_callbacks(self) -> None:
 
     @staticmethod
     def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
-        """Moves all Checkpoint callbacks to the end of the list. The sequential order within the group of
-        checkpoint callbacks is preserved, as well as the order of all other callbacks.
+        """Moves all the tuner specific callbacks at the beginning of the list and all the `ModelCheckpoint`
+        callbacks to the end of the list. The sequential order within the group of checkpoint callbacks is
+        preserved, as well as the order of all other callbacks.
 
         Args:
             callbacks: A list of callbacks.
 
         Return:
-            A new list in which the last elements are Checkpoint if there were any present in the
-            input.
+            A new list in which the first elements are tuner specific callbacks and last elements are ModelCheckpoints
+            if there were any present in the input.
         """
-        checkpoints: List[Callback] = [c for c in callbacks if isinstance(c, Checkpoint)]
-        not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
-        return not_checkpoints + checkpoints
+        tuner_callbacks: List[Callback] = []
+        other_callbacks: List[Callback] = []
+        checkpoint_callbacks: List[Callback] = []
+
+        for cb in callbacks:
+            if isinstance(cb, BatchSizeFinder):
+                tuner_callbacks.append(cb)
+            elif isinstance(cb, Checkpoint):
+                checkpoint_callbacks.append(cb)
+            else:
+                other_callbacks.append(cb)
+
+        return tuner_callbacks + other_callbacks + checkpoint_callbacks
 
 
 def _configure_external_callbacks() -> List[Callback]:
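
As an aside (not part of the commit), the contract of `_reorder_callbacks` (tuner callbacks first, checkpoint callbacks last, relative order preserved inside each bucket) can be pictured with a standalone sketch; the names below are illustrative, not the library's internals.

# Standalone illustration of the same three-bucket partition:
# tuner callbacks first, checkpoints last, everything else in between,
# with the order inside each bucket preserved.
from typing import Callable, List


def reorder(callbacks: List[str], is_tuner: Callable, is_checkpoint: Callable) -> List[str]:
    tuner, other, ckpt = [], [], []
    for cb in callbacks:
        (tuner if is_tuner(cb) else ckpt if is_checkpoint(cb) else other).append(cb)
    return tuner + other + ckpt


if __name__ == "__main__":
    cbs = ["ModelCheckpoint", "EarlyStopping", "BatchSizeFinder", "LearningRateMonitor"]
    print(reorder(cbs, lambda c: c == "BatchSizeFinder", lambda c: c == "ModelCheckpoint"))
    # ['BatchSizeFinder', 'EarlyStopping', 'LearningRateMonitor', 'ModelCheckpoint']

Since Lightning invokes callbacks in list order, placing tuner callbacks first means a `BatchSizeFinder` hook runs before the hooks of all other callbacks for the same event.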

src/pytorch_lightning/trainer/teardown.py

Lines changed: 8 additions & 0 deletions
@@ -16,6 +16,7 @@
 
 from lightning_lite.utilities.distributed import distributed_available
 from pytorch_lightning.trainer.states import TrainerStatus
+from pytorch_lightning.utilities.exceptions import _TunerExitException
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 
 
@@ -34,6 +35,13 @@ def call_and_handle_interrupt(trainer: Any, trainer_fn: Callable, *args: Any, **
             return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
         else:
             return trainer_fn(*args, **kwargs)
+
+    except _TunerExitException:
+        trainer._call_teardown_hook()
+        trainer._teardown()
+        trainer.state.status = TrainerStatus.FINISHED
+        trainer.state.stage = None
+
     # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
     except KeyboardInterrupt as exception:
         rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")

src/pytorch_lightning/trainer/trainer.py

Lines changed: 16 additions & 23 deletions
@@ -41,6 +41,7 @@
 from torch import Tensor
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from typing_extensions import Literal
 
 import pytorch_lightning as pl
 from lightning_lite.utilities.cloud_io import get_filesystem
@@ -889,9 +890,11 @@ def tune(
         model: "pl.LightningModule",
         train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
         val_dataloaders: Optional[EVAL_DATALOADERS] = None,
+        dataloaders: Optional[EVAL_DATALOADERS] = None,
         datamodule: Optional[LightningDataModule] = None,
         scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
         lr_find_kwargs: Optional[Dict[str, Any]] = None,
+        method: Literal["fit", "validate", "test", "predict"] = "fit",
     ) -> _TunerResult:
         r"""
         Runs routines to tune hyperparameters before training.
@@ -905,44 +908,34 @@ def tune(
 
             val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
 
+            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
+                samples used for running tuner on validation/testing/prediction.
+
             datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
 
             scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size`
 
             lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find`
+
+            method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
         """
         if not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.tune()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
 
         Trainer._log_api_event("tune")
 
-        self.state.fn = TrainerFn.TUNING
-        self.state.status = TrainerStatus.RUNNING
-        self.tuning = True
-
-        # if a datamodule comes in as the second arg, then fix it for the user
-        if isinstance(train_dataloaders, LightningDataModule):
-            datamodule = train_dataloaders
-            train_dataloaders = None
-        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
-        if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
-            raise MisconfigurationException(
-                "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`"
-            )
-
-        # links data to the trainer
-        self._data_connector.attach_data(
-            model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
-        )
-
         with isolate_rng():
             result = self.tuner._tune(
-                model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs
+                model,
+                train_dataloaders,
+                val_dataloaders,
+                dataloaders,
+                datamodule,
+                scale_batch_size_kwargs=scale_batch_size_kwargs,
+                lr_find_kwargs=lr_find_kwargs,
+                method=method,
             )
 
-        assert self.state.stopped
-        self.tuning = False
-
         return result
 
     def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
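
For orientation (not part of the diff), a hedged sketch of calling the extended `Trainer.tune()` shown above with the new `method` argument; `MyLightningModule` and `MyDataModule` are placeholders for user code, so the snippet is not runnable as-is.

# Hypothetical call sketch for the extended `Trainer.tune()` signature above.
import pytorch_lightning as pl

model = MyLightningModule()   # placeholder: your module exposing a `batch_size` attribute
datamodule = MyDataModule()   # placeholder: provides the val/test/predict dataloaders

trainer = pl.Trainer(max_epochs=1)

# Run the batch-size scaling routine against the validation loop instead of `fit`.
result = trainer.tune(
    model,
    datamodule=datamodule,
    scale_batch_size_kwargs={"mode": "binsearch", "steps_per_trial": 3},
    method="validate",
)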
