
Commit 9cc4695

Revert "Add BatchSizeFinder callback (#11089)"
This reverts commit d1a3a3e.
1 parent 9fc4ff3 commit 9cc4695
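
This revert removes the `BatchSizeFinder` callback introduced in #11089; afterwards, batch-size scaling is reached only through the tuner entry point, whose restored `Trainer.tune()` signature appears in the `trainer.py` diff below. Here is a minimal sketch of the post-revert usage; the toy `LitModel` and the 1.x-era `auto_scale_batch_size` Trainer flag are assumptions for illustration and do not appear in this commit, only `tune(..., scale_batch_size_kwargs=...)` is taken from the restored code.

```python
# Sketch of batch-size scaling via the tuner entry point (post-revert). `LitModel`,
# `auto_scale_batch_size` and `init_val` are assumptions about the PL 1.x API of this era.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.save_hyperparameters()  # exposes `self.hparams.batch_size`, which the tuner scales
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
        return DataLoader(dataset, batch_size=self.hparams.batch_size)


trainer = pl.Trainer(auto_scale_batch_size="power", max_epochs=1)  # assumed 1.x-era flag
trainer.tune(LitModel(), scale_batch_size_kwargs={"init_val": 2})  # assumed kwarg of scale_batch_size
```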

File tree

15 files changed (+169, -640 lines)


docs/source-pytorch/api_references.rst

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ callbacks
     BackboneFinetuning
     BaseFinetuning
     BasePredictionWriter
-    BatchSizeFinder
     Callback
     DeviceStatsMonitor
     EarlyStopping

docs/source-pytorch/extensions/callbacks.rst

Lines changed: 0 additions & 1 deletion
@@ -87,7 +87,6 @@ Lightning has a few built-in callbacks.
     BackboneFinetuning
     BaseFinetuning
     BasePredictionWriter
-    BatchSizeFinder
     Callback
     DeviceStatsMonitor
     EarlyStopping

src/pytorch_lightning/CHANGELOG.md

Lines changed: 0 additions & 5 deletions
@@ -9,11 +9,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Added `BatchSizeFinder` callback ([#11089](https://github.com/PyTorchLightning/pytorch-lightning/pull/11089))
-
-
-- Tuner now supports a new `method` argument which will determine when to run the `BatchSizeFinder`: one of `fit`, `validate`, `test` or `predict` ([#11089](https://github.com/PyTorchLightning/pytorch-lightning/pull/11089))
-
 
 - Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290))
 

src/pytorch_lightning/callbacks/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
 from pytorch_lightning.callbacks.callback import Callback
 from pytorch_lightning.callbacks.checkpoint import Checkpoint
 from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
@@ -33,8 +32,6 @@
 __all__ = [
     "BackboneFinetuning",
     "BaseFinetuning",
-    "BasePredictionWriter",
-    "BatchSizeFinder",
     "Callback",
     "Checkpoint",
     "DeviceStatsMonitor",
@@ -45,6 +42,7 @@
     "ModelCheckpoint",
     "ModelPruning",
     "ModelSummary",
+    "BasePredictionWriter",
     "ProgressBarBase",
     "QuantizationAwareTraining",
     "RichModelSummary",

src/pytorch_lightning/callbacks/batch_size_finder.py

Lines changed: 0 additions & 149 deletions
This file was deleted.

src/pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 7 additions & 19 deletions
@@ -28,7 +28,6 @@
     RichProgressBar,
     TQDMProgressBar,
 )
-from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
 from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
 from pytorch_lightning.callbacks.timer import Timer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -230,30 +229,19 @@ def _attach_model_callbacks(self) -> None:
 
     @staticmethod
     def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
-        """Moves all the tuner specific callbacks at the beginning of the list and all the `ModelCheckpoint`
-        callbacks to the end of the list. The sequential order within the group of checkpoint callbacks is
-        preserved, as well as the order of all other callbacks.
+        """Moves all Checkpoint callbacks to the end of the list. The sequential order within the group of
+        checkpoint callbacks is preserved, as well as the order of all other callbacks.
 
         Args:
             callbacks: A list of callbacks.
 
         Return:
-            A new list in which the first elements are tuner specific callbacks and last elements are ModelCheckpoints
-            if there were any present in the input.
+            A new list in which the last elements are Checkpoint if there were any present in the
+            input.
         """
-        tuner_callbacks: List[Callback] = []
-        other_callbacks: List[Callback] = []
-        checkpoint_callbacks: List[Callback] = []
-
-        for cb in callbacks:
-            if isinstance(cb, BatchSizeFinder):
-                tuner_callbacks.append(cb)
-            elif isinstance(cb, Checkpoint):
-                checkpoint_callbacks.append(cb)
-            else:
-                other_callbacks.append(cb)
-
-        return tuner_callbacks + other_callbacks + checkpoint_callbacks
+        checkpoints: List[Callback] = [c for c in callbacks if isinstance(c, Checkpoint)]
+        not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
+        return not_checkpoints + checkpoints
 
 
 def _configure_external_callbacks() -> List[Callback]:
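
The restored `_reorder_callbacks` no longer gives tuner callbacks a privileged front position; it only moves `Checkpoint` instances to the back while preserving relative order. A self-contained sketch of that rule, with stub classes standing in for the real Lightning types:

```python
# Sketch of the ordering rule restored above: Checkpoint callbacks go to the end of the
# list, everything else keeps its relative position. Stub classes replace the real
# pytorch_lightning Callback/Checkpoint types so the snippet runs on its own.
from typing import List


class Callback: ...
class Checkpoint(Callback): ...
class EarlyStopping(Callback): ...
class ModelCheckpoint(Checkpoint): ...


def reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
    checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)]
    not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
    return not_checkpoints + checkpoints


cbs = [ModelCheckpoint(), EarlyStopping(), ModelCheckpoint()]
print([type(c).__name__ for c in reorder_callbacks(cbs)])
# -> ['EarlyStopping', 'ModelCheckpoint', 'ModelCheckpoint']; both checkpoints keep their order
```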

src/pytorch_lightning/trainer/teardown.py

Lines changed: 0 additions & 8 deletions
@@ -16,7 +16,6 @@
 
 from lightning_lite.utilities.distributed import distributed_available
 from pytorch_lightning.trainer.states import TrainerStatus
-from pytorch_lightning.utilities.exceptions import _TunerExitException
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 
 
@@ -35,13 +34,6 @@ def call_and_handle_interrupt(trainer: Any, trainer_fn: Callable, *args: Any, **
             return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
         else:
             return trainer_fn(*args, **kwargs)
-
-    except _TunerExitException:
-        trainer._call_teardown_hook()
-        trainer._teardown()
-        trainer.state.status = TrainerStatus.FINISHED
-        trainer.state.stage = None
-
     # TODO: Unify both exceptions below, where `KeyboardError` doesn't re-raise
     except KeyboardInterrupt as exception:
         rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
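
The deleted `except _TunerExitException` branch was what let the callback-driven tuner abort an in-flight run early while still tearing down cleanly. A small sketch of that removed control flow, using a stand-in exception class since the real `_TunerExitException` import is gone after this commit:

```python
# Sketch only: mimics the shape of the control flow removed above. `TunerExit` is a
# stand-in; the real `_TunerExitException` lives in pytorch_lightning.utilities.exceptions.
class TunerExit(Exception):
    """Raised by a tuner component to stop the run early but gracefully."""


def call_with_tuner_exit(trainer_fn, *args, **kwargs):
    try:
        return trainer_fn(*args, **kwargs)
    except TunerExit:
        # pre-revert, this branch ran the teardown hooks and marked the run as FINISHED
        print("tuner requested an early, graceful shutdown")


def fake_fit():
    raise TunerExit  # e.g. a batch-size finder deciding it has seen enough batches


call_with_tuner_exit(fake_fit)  # prints the message instead of propagating the exception
```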

src/pytorch_lightning/trainer/trainer.py

Lines changed: 23 additions & 16 deletions
@@ -41,7 +41,6 @@
 from torch import Tensor
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
-from typing_extensions import Literal
 
 import pytorch_lightning as pl
 from lightning_lite.utilities.cloud_io import get_filesystem
@@ -890,11 +889,9 @@ def tune(
         model: "pl.LightningModule",
         train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
         val_dataloaders: Optional[EVAL_DATALOADERS] = None,
-        dataloaders: Optional[EVAL_DATALOADERS] = None,
         datamodule: Optional[LightningDataModule] = None,
         scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
         lr_find_kwargs: Optional[Dict[str, Any]] = None,
-        method: Literal["fit", "validate", "test", "predict"] = "fit",
     ) -> _TunerResult:
         r"""
         Runs routines to tune hyperparameters before training.
@@ -908,34 +905,44 @@ def tune(
 
             val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
 
-            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
-                samples used for running tuner on validation/testing/prediction.
-
             datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
 
             scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size`
 
             lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find`
-
-            method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
         """
         if not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.tune()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
 
         Trainer._log_api_event("tune")
 
+        self.state.fn = TrainerFn.TUNING
+        self.state.status = TrainerStatus.RUNNING
+        self.tuning = True
+
+        # if a datamodule comes in as the second arg, then fix it for the user
+        if isinstance(train_dataloaders, LightningDataModule):
+            datamodule = train_dataloaders
+            train_dataloaders = None
+        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
+        if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
+            raise MisconfigurationException(
+                "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`"
+            )
+
+        # links data to the trainer
+        self._data_connector.attach_data(
+            model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
+        )
+
         with isolate_rng():
             result = self.tuner._tune(
-                model,
-                train_dataloaders,
-                val_dataloaders,
-                dataloaders,
-                datamodule,
-                scale_batch_size_kwargs=scale_batch_size_kwargs,
-                lr_find_kwargs=lr_find_kwargs,
-                method=method,
+                model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs
             )
 
+        assert self.state.stopped
+        self.tuning = False
+
         return result
 
     def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
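
The restored preamble of `tune()` re-routes a datamodule passed as the second positional argument and rejects mixing explicit dataloaders with a datamodule. A standalone sketch of just that argument handling, with stub types so it runs outside Lightning:

```python
# Sketch of the argument handling restored inside `tune()` above; stub types replace the
# real LightningDataModule and MisconfigurationException so the snippet is self-contained.
from typing import Any, Optional, Tuple


class LightningDataModule: ...
class MisconfigurationException(Exception): ...


def resolve_tune_data(
    train_dataloaders: Any = None,
    val_dataloaders: Any = None,
    datamodule: Optional[LightningDataModule] = None,
) -> Tuple[Any, Any, Optional[LightningDataModule]]:
    # if a datamodule comes in as the second arg, then fix it for the user
    if isinstance(train_dataloaders, LightningDataModule):
        datamodule = train_dataloaders
        train_dataloaders = None
    # a datamodule cannot be combined with explicit train/val dataloaders
    if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
        raise MisconfigurationException(
            "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`"
        )
    return train_dataloaders, val_dataloaders, datamodule


print(resolve_tune_data(LightningDataModule()))  # datamodule re-routed: (None, None, <LightningDataModule>)
```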
