
Commit f5e9a5a

keep tune and remove early_exit
1 parent 9c6bcc1 commit f5e9a5a

4 files changed (+116, -149 lines)

pytorch_lightning/callbacks/batch_size_finder.py (+7, -84)
@@ -29,6 +29,7 @@
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.tuner.tuning import _TunerExitException
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.data import has_len_all_ranks
 from pytorch_lightning.utilities.distributed import rank_zero_info
@@ -46,7 +47,6 @@ def __init__(
         init_val=2,
         max_trials=25,
         batch_arg_name="batch_size",
-        early_exit=False,
     ):
         """Callback try to find the largest batch size for a given model that does not give an out of memory (OOM)
         error. It works with both training and evalation. All you need to do is add it as a callback inside Trainer
@@ -56,7 +56,7 @@ def __init__(
         Args:
             mode: search strategy to update the batch size:

-                - ``'power'`` (default): Keep multiplying the batch size by 2, until we get an OOM error.
+                - ``'power'``: Keep multiplying the batch size by 2, until we get an OOM error.
                 - ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error
                     do a binary search between the last successful batch size and the batch size that failed.
@@ -76,9 +76,6 @@ def __init__(
                 - ``model``
                 - ``model.hparams``
                 - ``trainer.datamodule`` (the datamodule passed to the tune method)
-
-            early_exit: whether to continue with the training/evaluation or stop after
-                an optimal batch size has been found.
         """
         supported_modes = ("power", "binsearch")
         mode = mode.lower()
@@ -91,7 +88,8 @@ def __init__(
         self.max_trials = max_trials
         self.batch_arg_name = batch_arg_name
         self.optimal_batch_size = init_val
-        self.early_exit = early_exit
+
+        self._early_exit = False

     def scale_batch_size(self, trainer, pl_module):
         if trainer.fast_dev_run:
@@ -165,6 +163,9 @@ def scale_batch_size(self, trainer, pl_module):
             print(f"new batch size: {new_size}")
         self.optimal_batch_size = new_size

+        if self._early_exit:
+            raise _TunerExitException()
+
     def _run_power_scaling(self, trainer, pl_module, new_size):
         """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
         for _ in range(self.max_trials):
@@ -332,99 +333,21 @@ def _restore_params(self, trainer):
         if "loop_verbose" in self._dumped_params:
             loop.verbose = self._dumped_params["loop_verbose"]

-    def pre_early_exit(self, trainer):
-        if trainer.fast_dev_run:
-            return
-
-        # this is required to stop the respective loops
-        if trainer.state.fn == TrainerFn.FITTING:
-            self._dumped_params["num_training_batches"] = trainer.num_training_batches
-            trainer.num_training_batches = 0
-        elif trainer.state.fn == TrainerFn.VALIDATING:
-            self._dumped_params["num_val_batches"] = trainer.num_val_batches
-            trainer.num_val_batches = [0]
-        elif trainer.state.fn == TrainerFn.TESTING:
-            self._dumped_params["num_test_batches"] = trainer.num_test_batches
-            trainer.num_test_batches = [0]
-        elif trainer.state.fn == TrainerFn.PREDICTING:
-            self._dumped_params["num_predict_batches"] = trainer.num_predict_batches
-            trainer.num_predict_batches = [0]
-
-    def post_early_exit(self, trainer):
-        if trainer.fast_dev_run:
-            return
-
-        # restore the state used to stop the respective loop
-        if trainer.state.fn == TrainerFn.FITTING:
-            trainer.num_training_batches = self._dumped_params["num_training_batches"]
-            loop = trainer.fit_loop
-        if trainer.state.fn == TrainerFn.VALIDATING:
-            trainer.num_val_batches = self._dumped_params["num_val_batches"]
-            loop = trainer.validate_loop
-        if trainer.state.fn == TrainerFn.TESTING:
-            trainer.num_test_batches = self._dumped_params["num_test_batches"]
-            loop = trainer.test_loop
-        if trainer.state.fn == TrainerFn.PREDICTING:
-            trainer.num_predict_batches = self._dumped_params["num_predict_batches"]
-            loop = trainer.predict_loop
-
-        loop.load_state_dict(self._dumped_params["loop_state_dict"], force_load_progress=True)
-        trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
-
     def on_fit_start(self, trainer, pl_module):
         self.scale_batch_size(trainer, pl_module)

-        if self.early_exit:
-            self.pre_early_exit(trainer)
-        else:
-            trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
-
     def on_validation_start(self, trainer, pl_module):
         if trainer.sanity_checking or trainer.state.fn != TrainerFn.VALIDATING:
             return

         self.scale_batch_size(trainer, pl_module)

-        if self.early_exit:
-            self.pre_early_exit(trainer)
-        else:
-            trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
-
     def on_test_start(self, trainer, pl_module):
         self.scale_batch_size(trainer, pl_module)

-        if self.early_exit:
-            self.pre_early_exit(trainer)
-        else:
-            trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
-
     def on_predict_start(self, trainer, pl_module):
         self.scale_batch_size(trainer, pl_module)

-        if self.early_exit:
-            self.pre_early_exit(trainer)
-        else:
-            trainer.callbacks = [cb for cb in trainer.callbacks if not isinstance(cb, BatchSizeFinder)]
-
-    def on_fit_end(self, trainer, pl_module):
-        if self.early_exit:
-            self.post_early_exit(trainer)
-
-    def on_validation_end(self, trainer, pl_module):
-        if trainer.sanity_checking or trainer.state.fn != TrainerFn.VALIDATING:
-            return
-
-        if self.early_exit:
-            self.post_early_exit(trainer)
-
-    def on_test_end(self, trainer, pl_module):
-        if self.early_exit:
-            self.post_early_exit(trainer)
-
-    def on_predict_end(self, trainer, pl_module):
-        if self.early_exit:
-            self.post_early_exit(trainer)
-
     def _adjust_batch_size(
         self,
         trainer: "pl.Trainer",

pytorch_lightning/trainer/connectors/callback_connector.py (+3, -3)
@@ -306,6 +306,6 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
         checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)]
         not_checkpoints = [c for c in callbacks if not isinstance(c, ModelCheckpoint)]
         callbacks = not_checkpoints + checkpoints
-        batch_size_finder_callback = [c for c in callbacks if isinstance(c, BatchSizeFinder)]
-        other_callbacks = [c for c in callbacks if not isinstance(c, BatchSizeFinder)]
-        return batch_size_finder_callback + other_callbacks
+        tuner_callbacks = [c for c in callbacks if isinstance(c, BatchSizeFinder)]
+        non_tuner_callbacks = [c for c in callbacks if not isinstance(c, BatchSizeFinder)]
+        return tuner_callbacks + non_tuner_callbacks
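Note: the rename above is cosmetic; the ordering rule itself is unchanged. As an illustration of that rule (assumed behaviour inferred from the code above; the `BatchSizeFinder` import path is inferred from the file path in this commit), tuner callbacks are moved to the front so their hooks fire before any other callback, while `ModelCheckpoint` instances stay at the end:

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder

# User-supplied order:
callbacks = [ModelCheckpoint(), EarlyStopping(monitor="val_loss"), BatchSizeFinder()]
# Expected order after _reorder_callbacks: [BatchSizeFinder, EarlyStopping, ModelCheckpoint]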

pytorch_lightning/trainer/trainer.py (+23, -24)
@@ -71,7 +71,7 @@
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
 from pytorch_lightning.tuner.lr_finder import _LRFinder
-from pytorch_lightning.tuner.tuning import Tuner
+from pytorch_lightning.tuner.tuning import _TunerExitException, Tuner
 from pytorch_lightning.utilities import (
     _AcceleratorType,
     _IPU_AVAILABLE,
@@ -678,6 +678,19 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs:
                 return spawn_output.trainer_results
             else:
                 return trainer_fn(*args, **kwargs)
+
+        except _TunerExitException as exception:
+            self.state.status = TrainerStatus.FINISHED
+            if distributed_available() and self.world_size > 1:
+                # try syncing remaing processes, kill otherwise
+                self.strategy.reconciliate_processes(traceback.format_exc())
+            self._on_exception()
+            # reset bookkeeping
+            self.state.stage = None
+            self._call_callback_hooks("on_exception", exception)
+            # shutdown workers
+            self._data_connector.teardown()
+
         # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
         except KeyboardInterrupt as exception:
             rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
@@ -1027,9 +1040,11 @@ def tune(
         model: "pl.LightningModule",
         train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
         val_dataloaders: Optional[EVAL_DATALOADERS] = None,
+        dataloaders: Optional[EVAL_DATALOADERS] = None,
         datamodule: Optional[LightningDataModule] = None,
         scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
         lr_find_kwargs: Optional[Dict[str, Any]] = None,
+        method="fit",
     ) -> Dict[str, Optional[Union[int, _LRFinder]]]:
         r"""
         Runs routines to tune hyperparameters before training.
@@ -1043,44 +1058,28 @@ def tune(

             val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.

+            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
+                samples used for running tuner on validation/testing/prediction.
+
             datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

             scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size`

             lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find`
-        """
-        Trainer._log_api_event("tune")
-        self.state.fn = TrainerFn.TUNING
-        self.state.status = TrainerStatus.RUNNING
-        self.tuning = True
-
-        # if a datamodule comes in as the second arg, then fix it for the user
-        if isinstance(train_dataloaders, LightningDataModule):
-            datamodule = train_dataloaders
-            train_dataloaders = None
-        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
-        if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
-            raise MisconfigurationException(
-                "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`"
-            )
-
-        # links data to the trainer
-        self._data_connector.attach_data(
-            model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
-        )

+            method: Method to run tuner on. It can be ``'fit', 'validate', 'test', 'predict'``
+        """
         result = self.tuner._tune(
             model,
             train_dataloaders,
             val_dataloaders,
+            dataloaders,
             datamodule,
             scale_batch_size_kwargs=scale_batch_size_kwargs,
             lr_find_kwargs=lr_find_kwargs,
+            method=method,
         )

-        assert self.state.stopped
-        self.tuning = False
-
         return result

     def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
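Note: the datamodule validation and data-attachment logic is no longer done in `Trainer.tune` itself; the method now just forwards its arguments to `self.tuner._tune`. A minimal sketch of the extended call shape (illustrative; `model`, `train_loader`, and `val_loader` are placeholders for an existing `LightningModule` and dataloaders):

import pytorch_lightning as pl

trainer = pl.Trainer()

# Unchanged: tune against the fit stage (the default `method="fit"`).
trainer.tune(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# New in this commit: run the tuner against another stage by combining the
# stage-agnostic `dataloaders` argument with `method`.
trainer.tune(model, dataloaders=val_loader, method="validate")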
