diff --git a/docs/source-pytorch/advanced/training_tricks.rst b/docs/source-pytorch/advanced/training_tricks.rst
index 456bf027445c9..6d39c1550ee20 100644
--- a/docs/source-pytorch/advanced/training_tricks.rst
+++ b/docs/source-pytorch/advanced/training_tricks.rst
@@ -78,18 +78,24 @@ Auto-scaling of batch size can be enabled to find the largest batch size that fi
memory. Large batch size often yields a better estimation of the gradients, but may also result in
longer training time. Inspired by https://github.com/BlackHC/toma.
-.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`
+.. seealso:: :class:`~pytorch_lightning.tuner.tuning.Tuner`
.. code-block:: python
- # DEFAULT (ie: don't scale batch size automatically)
- trainer = Trainer(auto_scale_batch_size=None)
+ from pytorch_lightning.tuner import Tuner
+
+ # Create a tuner for the trainer
+ trainer = Trainer(...)
+ tuner = Tuner(trainer)
- # Autoscale batch size
- trainer = Trainer(auto_scale_batch_size=None | "power" | "binsearch")
+ # Auto-scale batch size by growing it exponentially (default)
+ tuner.scale_batch_size(model, mode="power")
- # Find the batch size
- trainer.tune(model)
+ # Auto-scale batch size with binary search
+ tuner.scale_batch_size(model, mode="binsearch")
+
+ # Fit as normal with new batch size
+ trainer.fit(model)
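+
+For reference, ``scale_batch_size`` also returns the batch size it settled on, so you can capture it
+explicitly if you want to log or reuse it (an illustrative sketch; the batch size attribute on your
+model is updated automatically):
+
+.. code-block:: python
+
+    new_batch_size = tuner.scale_batch_size(model, mode="binsearch")
+    print(f"Scaled batch size: {new_batch_size}")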
Currently, this feature supports two modes ``'power'`` scaling and ``'binsearch'``
@@ -122,9 +128,10 @@ search for batch sizes larger than the size of the training dataset.
return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)
- trainer = Trainer(...)
model = LitModel(batch_size=32)
- trainer.tune(model)
+ trainer = Trainer(...)
+ tuner = Tuner(trainer)
+ tuner.scale_batch_size(model)
# using LightningDataModule
class LitDataModule(LightningDataModule):
@@ -138,10 +145,12 @@ search for batch sizes larger than the size of the training dataset.
return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)
- trainer = Trainer(...)
model = MyModel()
datamodule = LitDataModule(batch_size=32)
- trainer.tune(model, datamodule=datamodule)
+
+ trainer = Trainer(...)
+ tuner = Tuner(trainer)
+ tuner.scale_batch_size(model, datamodule=datamodule)
Note that the ``train_dataloader`` can be either part of
the ``LightningModule`` or ``LightningDataModule``
@@ -149,29 +158,6 @@ search for batch sizes larger than the size of the training dataset.
and the ``LightningDataModule`` contain a ``train_dataloader``,
the ``LightningDataModule`` takes precedence.
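+
+As a minimal sketch of that precedence rule (assuming both classes from the snippets above define a
+``train_dataloader``):
+
+.. code-block:: python
+
+    model = LitModel(batch_size=32)
+    datamodule = LitDataModule(batch_size=32)
+
+    trainer = Trainer(...)
+    tuner = Tuner(trainer)
+
+    # the train_dataloader of the LightningDataModule is used, not the one defined on the model
+    tuner.scale_batch_size(model, datamodule=datamodule)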
-.. warning::
-
- Due to the constraints listed above, this features does *NOT* work when passing dataloaders directly
- to ``.fit()``.
-
-The scaling algorithm has a number of parameters that the user can control by
-invoking the :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size` method:
-
-.. code-block:: python
-
- # Use default in trainer construction
- trainer = Trainer()
- tuner = Tuner(trainer)
-
- # Invoke method
- new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here)
-
- # Override old batch size (this is done automatically)
- model.hparams.batch_size = new_batch_size
-
- # Fit as normal
- trainer.fit(model)
-
The algorithm in short works by:
1. Dumping the current state of the model and trainer
2. Iteratively until convergence or maximum number of tries ``max_trials`` (default 25) has been reached:
@@ -247,14 +233,6 @@ Customizing Batch Size Finder
Learning Rate Finder
********************
-.. raw:: html
-
-
-
-|
-
For training deep neural networks, selecting a good learning rate is essential
for both better performance and faster convergence. Even optimizers such as
:class:`~torch.optim.Adam` that are self-adjusting the learning rate can benefit from more optimal
@@ -284,16 +262,17 @@ Using Lightning's built-in LR finder
To enable the learning rate finder, your :doc:`lightning module <../common/lightning_module>` needs to
have a ``learning_rate`` or ``lr`` attribute (or as a field in your ``hparams`` i.e.
-``hparams.learning_rate`` or ``hparams.lr``). Then, set ``Trainer(auto_lr_find=True)``
-during trainer construction, and then call ``trainer.tune(model)`` to run the LR finder.
+``hparams.learning_rate`` or ``hparams.lr``). Then, create the :class:`~pytorch_lightning.tuner.tuning.Tuner` via ``tuner = Tuner(trainer)``
+and call ``tuner.lr_find(model)`` to run the LR finder.
The suggested ``learning_rate`` will be written to the console and will be automatically
set to your :doc:`lightning module <../common/lightning_module>`, which can be accessed
via ``self.learning_rate`` or ``self.lr``.
-.. seealso:: :ref:`trainer.tune `.
-
.. code-block:: python
+ from pytorch_lightning.tuner import Tuner
+
+
class LitModel(LightningModule):
def __init__(self, learning_rate):
super().__init__()
@@ -305,36 +284,39 @@ via ``self.learning_rate`` or ``self.lr``.
model = LitModel()
+ trainer = Trainer(...)
+
+ # Create a Tuner
+ tuner = Tuner(trainer)
# finds learning rate automatically
# sets hparams.lr or hparams.learning_rate to that learning rate
- trainer = Trainer(auto_lr_find=True)
-
- trainer.tune(model)
+ tuner.lr_find(model)
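+
+    # illustrative follow-up: the suggested learning rate is now set on the model,
+    # so training proceeds with it as usual
+    trainer.fit(model)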
-If your model is using an arbitrary value instead of ``self.lr`` or ``self.learning_rate``, set that value as ``auto_lr_find``:
+If your model uses an arbitrary attribute name instead of ``self.lr`` or ``self.learning_rate``, pass that name to ``lr_find`` via ``attr_name``:
.. code-block:: python
model = LitModel()
+ trainer = Trainer(...)
+ tuner = Tuner(trainer)
# to set to your own hparams.my_value
- trainer = Trainer(auto_lr_find="my_value")
+ tuner.lr_find(model, attr_name="my_value")
- trainer.tune(model)
You can also inspect the results of the learning rate finder or just play around
-with the parameters of the algorithm. This can be done by invoking the
-:meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find` method. A typical example of this would look like:
+with the parameters of the algorithm. A typical example of this would look like:
.. code-block:: python
model = MyModelClass(hparams)
trainer = Trainer()
+ tuner = Tuner(trainer)
# Run learning rate finder
- lr_finder = trainer.tuner.lr_find(model)
+ lr_finder = tuner.lr_find(model)
# Results can be found in
print(lr_finder.results)
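+
+    # Illustrative follow-up: apply the suggestion manually
+    # (assumes your hparams expose an lr field; suggestion() may return None)
+    new_lr = lr_finder.suggestion()
+    if new_lr is not None:
+        model.hparams.lr = new_lr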
diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index 9e1590e8a2e89..d9286d8828913 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -287,69 +287,6 @@ Example::
# no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})
-auto_scale_batch_size
-^^^^^^^^^^^^^^^^^^^^^
-
-.. raw:: html
-
-
-
-|
-
-Automatically tries to find the largest batch size that fits into memory,
-before any training.
-
-.. code-block:: python
-
- # default used by the Trainer (no scaling of batch size)
- trainer = Trainer(auto_scale_batch_size=None)
-
- # run batch size scaling, result overrides hparams.batch_size
- trainer = Trainer(auto_scale_batch_size="binsearch")
-
- # call tune to find the batch size
- trainer.tune(model)
-
-
-auto_lr_find
-^^^^^^^^^^^^
-
-.. raw:: html
-
-
-
-|
-
-Runs a learning rate finder algorithm (see this `paper `_)
-when calling trainer.tune(), to find optimal initial learning rate.
-
-.. code-block:: python
-
- # default used by the Trainer (no learning rate finder)
- trainer = Trainer(auto_lr_find=False)
-
-Example::
-
- # run learning rate finder, results override hparams.learning_rate
- trainer = Trainer(auto_lr_find=True)
-
- # call tune to find the lr
- trainer.tune(model)
-
-Example::
-
- # run learning rate finder, results override hparams.my_lr_arg
- trainer = Trainer(auto_lr_find='my_lr_arg')
-
- # call tune to find the lr
- trainer.tune(model)
-
-.. note::
- See the :ref:`learning rate finder guide `.
benchmark
^^^^^^^^^
@@ -617,7 +554,7 @@ impact to subsequent runs. These are the changes enabled:
- The :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks will not trigger.
- The :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callbacks will not trigger.
- Sets ``limit_{train,val,test,predict}_batches`` to 1 or the number passed.
-- Disables the Tuner.
+- Disables the tuning callbacks (:class:`~pytorch_lightning.callbacks.batch_size_finder.BatchSizeFinder`, :class:`~pytorch_lightning.callbacks.lr_finder.LearningRateFinder`).
- If using the CLI, the configuration file is not saved.
@@ -1358,12 +1295,6 @@ predict
.. automethod:: pytorch_lightning.trainer.Trainer.predict
:noindex:
-tune
-****
-
-.. automethod:: pytorch_lightning.trainer.Trainer.tune
- :noindex:
-
Properties
^^^^^^^^^^
@@ -1523,11 +1454,11 @@ execution within that function, and the status of the Trainer.
.. code-block:: python
- # fn in ("fit", "validate", "test", "predict", "tune")
+ # fn in ("fit", "validate", "test", "predict")
trainer.state.fn
# status in ("initializing", "running", "finished", "interrupted")
trainer.state.status
- # stage in ("train", "sanity_check", "validate", "test", "predict", "tune")
+ # stage in ("train", "sanity_check", "validate", "test", "predict")
trainer.state.stage
should_stop
diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index e7c0239e3a75c..ada6f9215c740 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added an argument `include_cuda` in `pytorch_lightning.utilities.seed.isolate_rng` to disable managing `torch.cuda`'s rng ([#16423](https://github.com/Lightning-AI/lightning/pull/16423))
+- Added `Tuner.lr_find(attr_name=...)` to specify custom learning rate attribute names ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
+
- Added an `OnExceptionCheckpoint` callback to save a checkpoint on exception ([#16512](https://github.com/Lightning-AI/lightning/pull/16512))
- Added support for running the `MLFlowLogger` with the `mlflow-skinny` package ([16513](https://github.com/Lightning-AI/lightning/pull/16513))
@@ -31,6 +33,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* `pytorch_lightning.plugins.precision.native_amp` is now `pytorch_lightning.plugins.precision.amp`
* `NativeSyncBatchNorm` is now `TorchSyncBatchNorm`
+- Changed the default of `LearningRateFinder(update_attr=...)` and `Tuner.lr_find(update_attr=...)` to `True` ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
+
- Renamed the `pl.utilities.exceptions.GracefulExitException` to `SIGTERMException` ([#16501](https://github.com/Lightning-AI/lightning/pull/16501))
### Deprecated
@@ -142,9 +146,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed support for passing a dictionary value to `self.log()` ([#16389](https://github.com/Lightning-AI/lightning/pull/16389))
+- Removed `Trainer.model` setter ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
+
- Tuner removal
* Removed the deprecated `trainer.tuning` property ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))
* Removed the deprecated `TrainerFn.TUNING` and `RunningStage.TUNING` enums ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))
+ * Removed `Trainer.tune()` in favor of `Tuner(trainer).{lr_find,scale_batch_size}` ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
+ * Removed `Trainer(auto_scale_batch_size=...)` in favor of `Tuner(trainer).scale_batch_size()` ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
+ * Removed `Trainer(auto_lr_find=...)` in favor of `Tuner(trainer).lr_find()` ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
### Fixed
diff --git a/src/pytorch_lightning/callbacks/batch_size_finder.py b/src/pytorch_lightning/callbacks/batch_size_finder.py
index b0286910b81dc..002a2a498ea5a 100644
--- a/src/pytorch_lightning/callbacks/batch_size_finder.py
+++ b/src/pytorch_lightning/callbacks/batch_size_finder.py
@@ -22,7 +22,7 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
-from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
+from pytorch_lightning.tuner.batch_size_scaling import _scale_batch_size
from pytorch_lightning.utilities.exceptions import _TunerExitException, MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@@ -165,7 +165,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
)
def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- new_size = scale_batch_size(
+ new_size = _scale_batch_size(
trainer,
pl_module,
self._mode,
diff --git a/src/pytorch_lightning/callbacks/lr_finder.py b/src/pytorch_lightning/callbacks/lr_finder.py
index 1c950e64086b9..8a60c87de50f0 100644
--- a/src/pytorch_lightning/callbacks/lr_finder.py
+++ b/src/pytorch_lightning/callbacks/lr_finder.py
@@ -21,7 +21,7 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
-from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find
+from pytorch_lightning.tuner.lr_finder import _lr_find, _LRFinder
from pytorch_lightning.utilities.exceptions import _TunerExitException
from pytorch_lightning.utilities.seed import isolate_rng
@@ -32,11 +32,8 @@ class LearningRateFinder(Callback):
Args:
min_lr: Minimum learning rate to investigate
-
max_lr: Maximum learning rate to investigate
-
num_training_steps: Number of learning rates to test
-
mode: Search strategy to update learning rate after each batch:
- ``'exponential'`` (default): Increases the learning rate exponentially.
@@ -45,8 +42,9 @@ class LearningRateFinder(Callback):
early_stop_threshold: Threshold for stopping the search. If the
loss at any point is larger than early_stop_threshold*best_loss
then the search is stopped. To disable, set to None.
-
update_attr: Whether to update the learning rate attribute or not.
+ attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
+ automatically detected. Otherwise, set the name here.
Example::
@@ -73,8 +71,8 @@ def on_train_epoch_start(self, trainer, pl_module):
Raises:
MisconfigurationException:
- If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``,
- or if you are using more than one optimizer.
+ If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden, or if you are using more than
+ one optimizer.
"""
SUPPORTED_MODES = ("linear", "exponential")
@@ -86,7 +84,8 @@ def __init__(
num_training_steps: int = 100,
mode: str = "exponential",
early_stop_threshold: Optional[float] = 4.0,
- update_attr: bool = False,
+ update_attr: bool = True,
+ attr_name: str = "",
) -> None:
mode = mode.lower()
if mode not in self.SUPPORTED_MODES:
@@ -98,13 +97,14 @@ def __init__(
self._mode = mode
self._early_stop_threshold = early_stop_threshold
self._update_attr = update_attr
+ self._attr_name = attr_name
self._early_exit = False
self.lr_finder: Optional[_LRFinder] = None
def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
with isolate_rng():
- self.optimal_lr = lr_find(
+ self.optimal_lr = _lr_find(
trainer,
pl_module,
min_lr=self._min_lr,
@@ -113,6 +113,7 @@ def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Non
mode=self._mode,
early_stop_threshold=self._early_stop_threshold,
update_attr=self._update_attr,
+ attr_name=self._attr_name,
)
if self._early_exit:
diff --git a/src/pytorch_lightning/cli.py b/src/pytorch_lightning/cli.py
index 1930e2300b41c..d3f9f2a3e0460 100644
--- a/src/pytorch_lightning/cli.py
+++ b/src/pytorch_lightning/cli.py
@@ -436,7 +436,6 @@ def subcommands() -> Dict[str, Set[str]]:
"validate": {"model", "dataloaders", "datamodule"},
"test": {"model", "dataloaders", "datamodule"},
"predict": {"model", "dataloaders", "datamodule"},
- "tune": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
}
def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None:
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py
index 3f895ef814e9d..459a6e3ccb3b6 100644
--- a/src/pytorch_lightning/trainer/trainer.py
+++ b/src/pytorch_lightning/trainer/trainer.py
@@ -40,7 +40,6 @@
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader
-from typing_extensions import Literal
import pytorch_lightning as pl
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
@@ -78,7 +77,6 @@
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
from pytorch_lightning.trainer.supporters import CombinedLoader
-from pytorch_lightning.tuner.tuning import _TunerResult, Tuner
from pytorch_lightning.utilities import GradClipAlgorithmType, parsing
from pytorch_lightning.utilities.argparse import (
_defaults_from_env_vars,
@@ -148,10 +146,8 @@ def __init__(
benchmark: Optional[bool] = None,
deterministic: Optional[Union[bool, _LITERAL_WARN]] = None,
reload_dataloaders_every_n_epochs: int = 0,
- auto_lr_find: Union[bool, str] = False,
replace_sampler_ddp: bool = True,
detect_anomaly: bool = False,
- auto_scale_batch_size: Union[str, bool] = False,
plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
multiple_trainloader_mode: str = "max_size_cycle",
inference_mode: bool = True,
@@ -167,31 +163,6 @@ def __init__(
accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
Default: ``None``.
- auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
- trying to optimize initial learning for faster convergence. trainer.tune() method will
- set the suggested learning rate in self.lr or self.learning_rate in the LightningModule.
- To use a different key set a string instead of True with the key name.
- Default: ``False``.
-
- auto_scale_batch_size: If set to True, will `initially` run a batch size
- finder trying to find the largest batch size that fits into memory.
- The result will be stored in self.batch_size in the LightningModule
- or LightningDataModule depending on your setup.
- Additionally, can be set to either `power` that estimates the batch size through
- a power search or `binsearch` that estimates the batch size through a binary search.
- Default: ``False``.
-
- auto_select_gpus: If enabled and ``gpus`` or ``devices`` is an integer, pick available
- gpus automatically. This is especially useful when
- GPUs are configured to be in "exclusive mode", such
- that only one process at a time can access them.
- Default: ``False``.
-
- .. deprecated:: v1.9
- ``auto_select_gpus`` has been deprecated in v1.9.0 and will be removed in v2.0.0.
- Please use the function :func:`~lightning_fabric.accelerators.cuda.find_usable_cuda_devices`
- instead.
-
benchmark: The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
(``False`` if not manually set). If :paramref:`~pytorch_lightning.trainer.Trainer.deterministic` is set
@@ -365,7 +336,6 @@ def __init__(
self._callback_connector = CallbackConnector(self)
self._checkpoint_connector = CheckpointConnector(self)
self._signal_connector = SignalConnector(self)
- self.tuner = Tuner(self)
# init loops
self.fit_loop = _FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
@@ -429,9 +399,6 @@ def __init__(
self._detect_anomaly: bool = detect_anomaly
self._setup_on_init()
- # configure tuner
- self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)
-
# configure profiler
setup._init_profiler(self, profiler)
@@ -846,57 +813,6 @@ def _predict_impl(
return results
- def tune(
- self,
- model: "pl.LightningModule",
- train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
- val_dataloaders: Optional[EVAL_DATALOADERS] = None,
- dataloaders: Optional[EVAL_DATALOADERS] = None,
- datamodule: Optional[LightningDataModule] = None,
- scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
- lr_find_kwargs: Optional[Dict[str, Any]] = None,
- method: Literal["fit", "validate", "test", "predict"] = "fit",
- ) -> _TunerResult:
- r"""
- Runs routines to tune hyperparameters before training.
-
- Args:
- model: Model to tune.
-
- train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
- :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
- In the case of multiple dataloaders, please see this :ref:`section `.
-
- val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
-
- dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
- samples used for running tuner on validation/testing/prediction.
-
- datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
-
- scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size`
-
- lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find`
-
- method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
- """
- model = self._maybe_unwrap_optimized(model)
- Trainer._log_api_event("tune")
-
- with isolate_rng():
- result = self.tuner._tune(
- model,
- train_dataloaders,
- val_dataloaders,
- dataloaders,
- datamodule,
- scale_batch_size_kwargs=scale_batch_size_kwargs,
- lr_find_kwargs=lr_find_kwargs,
- method=method,
- )
-
- return result
-
def _run(
self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
@@ -1665,17 +1581,6 @@ def model(self) -> Optional[torch.nn.Module]:
"""
return self.strategy.model
- @model.setter
- def model(self, model: torch.nn.Module) -> None:
- """Setter for the model, pass-through to accelerator and plugin where the model reference is stored. Used
- by the Tuner to reset the state of Trainer and Accelerator.
-
- Args:
- model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending
- on the backend.
- """
- self.strategy.model = model
-
"""
General properties
"""
diff --git a/src/pytorch_lightning/tuner/__init__.py b/src/pytorch_lightning/tuner/__init__.py
index e69de29bb2d1d..dc816988a1ad3 100644
--- a/src/pytorch_lightning/tuner/__init__.py
+++ b/src/pytorch_lightning/tuner/__init__.py
@@ -0,0 +1,14 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning.tuner.tuning import Tuner # noqa: F401
diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py
index 0557383cbb406..f6c2a3eead5a0 100644
--- a/src/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/src/pytorch_lightning/tuner/batch_size_scaling.py
@@ -25,7 +25,7 @@
log = logging.getLogger(__name__)
-def scale_batch_size(
+def _scale_batch_size(
trainer: "pl.Trainer",
model: "pl.LightningModule",
mode: str = "power",
@@ -34,6 +34,32 @@ def scale_batch_size(
max_trials: int = 25,
batch_arg_name: str = "batch_size",
) -> Optional[int]:
+ """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
+ error.
+
+ Args:
+ trainer: A Trainer instance.
+ model: Model to tune.
+ mode: Search strategy to update the batch size:
+
+ - ``'power'``: Keep multiplying the batch size by 2, until we get an OOM error.
+ - ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error
+ do a binary search between the last successful batch size and the batch size that failed.
+
+ steps_per_trial: number of steps to run with a given batch size.
+ Ideally 1 should be enough to test if an OOM error occurs,
+            however, in practice, a few are needed
+ init_val: initial batch size to start the search with
+ max_trials: max number of increases in batch size done before
+ algorithm is terminated
+ batch_arg_name: name of the attribute that stores the batch size.
+ It is expected that the user has provided a model or datamodule that has a hyperparameter
+ with that name. We will look for this attribute name in the following places
+
+ - ``model``
+ - ``model.hparams``
+ - ``trainer.datamodule`` (the datamodule passed to the tune method)
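+
+    Example::
+
+        # illustrative sketch of roughly how the ``BatchSizeFinder`` callback invokes this helper
+        new_size = _scale_batch_size(trainer, model, mode="binsearch", init_val=2)
+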
+ """
if trainer.fast_dev_run:
rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
return None
diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 341c9d8bb1726..1075bdacdb526 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -44,23 +44,24 @@
log = logging.getLogger(__name__)
-def _determine_lr_attr_name(trainer: "pl.Trainer", model: "pl.LightningModule") -> str:
- if isinstance(trainer.auto_lr_find, str):
- if not lightning_hasattr(model, trainer.auto_lr_find):
- raise MisconfigurationException(
- f"`auto_lr_find` was set to {trainer.auto_lr_find}, however"
+def _determine_lr_attr_name(model: "pl.LightningModule", attr_name: str = "") -> str:
+ if attr_name:
+ if not lightning_hasattr(model, attr_name):
+ raise AttributeError(
+ f"The attribute name for the learning rate was set to {attr_name}, but"
" could not find this as a field in `model` or `model.hparams`."
)
- return trainer.auto_lr_find
+ return attr_name
attr_options = ("lr", "learning_rate")
for attr in attr_options:
if lightning_hasattr(model, attr):
return attr
- raise MisconfigurationException(
- "When `auto_lr_find=True`, either `model` or `model.hparams` should"
- f" have one of these fields: {attr_options} overridden."
+ raise AttributeError(
+ "When using the learning rate finder, either `model` or `model.hparams` should"
+ f" have one of these fields: {attr_options}. If your model has a different name for the learning rate, set"
+ f" it with `.lr_find(attr_name=...)`."
)
@@ -201,7 +202,7 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
return self.results["lr"][self._optimal_idx]
-def lr_find(
+def _lr_find(
trainer: "pl.Trainer",
model: "pl.LightningModule",
min_lr: float = 1e-8,
@@ -210,15 +211,36 @@ def lr_find(
mode: str = "exponential",
early_stop_threshold: Optional[float] = 4.0,
update_attr: bool = False,
+ attr_name: str = "",
) -> Optional[_LRFinder]:
- """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
+ """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in
+ picking a good starting learning rate.
+
+ Args:
+ trainer: A Trainer instance.
+ model: Model to tune.
+ min_lr: minimum learning rate to investigate
+ max_lr: maximum learning rate to investigate
+ num_training: number of learning rates to test
+ mode: Search strategy to update learning rate after each batch:
+
+ - ``'exponential'``: Increases the learning rate exponentially.
+ - ``'linear'``: Increases the learning rate linearly.
+
+ early_stop_threshold: Threshold for stopping the search. If the
+ loss at any point is larger than early_stop_threshold*best_loss
+ then the search is stopped. To disable, set to None.
+ update_attr: Whether to update the learning rate attribute or not.
+ attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
+ automatically detected. Otherwise, set the name here.
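+
+    Example::
+
+        # illustrative sketch of roughly how the ``LearningRateFinder`` callback invokes this helper
+        lr_finder = _lr_find(trainer, model, update_attr=True)
+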
+ """
if trainer.fast_dev_run:
rank_zero_warn("Skipping learning rate finder since `fast_dev_run` is enabled.")
return None
# Determine lr attr
if update_attr:
- lr_attr_name = _determine_lr_attr_name(trainer, model)
+ attr_name = _determine_lr_attr_name(model, attr_name)
# Save initial model, that is loaded after learning rate is found
ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
@@ -266,7 +288,7 @@ def lr_find(
# TODO: log lr.results to self.logger
if lr is not None:
- lightning_setattr(model, lr_attr_name, lr)
+ lightning_setattr(model, attr_name, lr)
log.info(f"Learning rate set to {lr}")
# Restore initial state of model
@@ -284,8 +306,6 @@ def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
"optimizer_frequencies": trainer.strategy.optimizer_frequencies,
"callbacks": trainer.callbacks,
"loggers": trainer.loggers,
- # TODO: check if this is required
- "auto_lr_find": trainer.auto_lr_find,
"max_steps": trainer.fit_loop.max_steps,
"limit_val_batches": trainer.limit_val_batches,
"loop_state_dict": deepcopy(trainer.fit_loop.state_dict()),
@@ -297,8 +317,6 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto
trainer.strategy.lr_scheduler_configs = []
trainer.strategy.optimizer_frequencies = []
- # avoid lr find being called multiple times
- trainer.auto_lr_find = False
# Use special lr logger callback
trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
# No logging
@@ -312,7 +330,6 @@ def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) ->
trainer.strategy.optimizers = params["optimizers"]
trainer.strategy.lr_scheduler_configs = params["lr_scheduler_configs"]
trainer.strategy.optimizer_frequencies = params["optimizer_frequencies"]
- trainer.auto_lr_find = params["auto_lr_find"]
trainer.callbacks = params["callbacks"]
trainer.loggers = params["loggers"]
trainer.fit_loop.max_steps = params["max_steps"]
diff --git a/src/pytorch_lightning/tuner/tuning.py b/src/pytorch_lightning/tuner/tuning.py
index b4dc0c654280a..daaab0e227c1e 100644
--- a/src/pytorch_lightning/tuner/tuning.py
+++ b/src/pytorch_lightning/tuner/tuning.py
@@ -11,105 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Optional, Union
-from typing_extensions import Literal, NotRequired, TypedDict
+from typing_extensions import Literal
import pytorch_lightning as pl
-from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
from pytorch_lightning.callbacks.callback import Callback
-from pytorch_lightning.callbacks.lr_finder import LearningRateFinder
-from pytorch_lightning.core.datamodule import LightningDataModule
-from pytorch_lightning.trainer.states import TrainerStatus
-from pytorch_lightning.tuner.lr_finder import _LRFinder
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
-class _TunerResult(TypedDict):
- lr_find: NotRequired[Optional[_LRFinder]]
- scale_batch_size: NotRequired[Optional[int]]
-
-
class Tuner:
"""Tuner class to tune your model."""
def __init__(self, trainer: "pl.Trainer") -> None:
- self.trainer = trainer
-
- def on_trainer_init(self, auto_lr_find: Union[str, bool], auto_scale_batch_size: Union[str, bool]) -> None:
- self.trainer.auto_lr_find = auto_lr_find
- self.trainer.auto_scale_batch_size = auto_scale_batch_size
-
- def _tune(
- self,
- model: "pl.LightningModule",
- train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
- val_dataloaders: Optional[EVAL_DATALOADERS] = None,
- dataloaders: Optional[EVAL_DATALOADERS] = None,
- datamodule: Optional[LightningDataModule] = None,
- scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
- lr_find_kwargs: Optional[Dict[str, Any]] = None,
- method: Literal["fit", "validate", "test", "predict"] = "fit",
- ) -> _TunerResult:
- scale_batch_size_kwargs = scale_batch_size_kwargs or {}
- lr_find_kwargs = lr_find_kwargs or {}
- # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
- result = _TunerResult()
-
- self.trainer.strategy.connect(model)
-
- is_tuning = self.trainer.auto_scale_batch_size
- if self.trainer._accelerator_connector.is_distributed and is_tuning:
- raise MisconfigurationException(
- "`trainer.tune()` is currently not supported with"
- f" `Trainer(strategy={self.trainer.strategy.strategy_name!r})`."
- )
-
- # Run auto batch size scaling
- if self.trainer.auto_scale_batch_size:
- if isinstance(self.trainer.auto_scale_batch_size, str):
- scale_batch_size_kwargs.setdefault("mode", self.trainer.auto_scale_batch_size)
-
- result["scale_batch_size"] = self.scale_batch_size(
- model, train_dataloaders, val_dataloaders, dataloaders, datamodule, method, **scale_batch_size_kwargs
- )
-
- # Run learning rate finder:
- if self.trainer.auto_lr_find:
- self.trainer.state.status = TrainerStatus.RUNNING
-
- # TODO: Remove this once LRFinder is converted to a Callback
- # if a datamodule comes in as the second arg, then fix it for the user
- if isinstance(train_dataloaders, LightningDataModule):
- datamodule = train_dataloaders
- train_dataloaders = None
-
- # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
- if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
- raise MisconfigurationException(
- "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune()`"
- " if datamodule is already passed."
- )
-
- # links da_a to the trainer
- self.trainer._data_connector.attach_data(
- model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
- )
-
- lr_find_kwargs.setdefault("update_attr", True)
- result["lr_find"] = self.lr_find(
- model, train_dataloaders, val_dataloaders, dataloaders, datamodule, method, **lr_find_kwargs
- )
- self.trainer.state.status = TrainerStatus.FINISHED
-
- return result
-
- def _run(self, *args: Any, **kwargs: Any) -> None:
- """`_run` wrapper to set the proper state during tuning, as this can be called multiple times."""
- self.trainer.state.status = TrainerStatus.RUNNING # last `_run` call might have set it to `FINISHED`
- self.trainer.training = True
- self.trainer._run(*args, **kwargs)
+ self._trainer = trainer
def scale_batch_size(
self,
@@ -130,20 +46,14 @@ def scale_batch_size(
Args:
model: Model to tune.
-
train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
In the case of multiple dataloaders, please see this :ref:`section `.
-
val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
-
dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
samples used for running tuner on validation/testing/prediction.
-
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
-
method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
-
mode: Search strategy to update the batch size:
- ``'power'``: Keep multiplying the batch size by 2, until we get an OOM error.
@@ -153,12 +63,9 @@ def scale_batch_size(
steps_per_trial: number of steps to run with a given batch size.
Ideally 1 should be enough to test if an OOM error occurs,
however in practise a few are needed
-
init_val: initial batch size to start the search with
-
max_trials: max number of increases in batch size done before
algorithm is terminated
-
batch_arg_name: name of the attribute that stores the batch size.
It is expected that the user has provided a model or datamodule that has a hyperparameter
with that name. We will look for this attribute name in the following places
@@ -167,7 +74,11 @@ def scale_batch_size(
- ``model.hparams``
- ``trainer.datamodule`` (the datamodule passed to the tune method)
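+
+        Example::
+
+            # illustrative sketch (assumes ``model`` exposes a ``batch_size`` attribute)
+            tuner = Tuner(trainer)
+            new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4)
+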
"""
- _check_tuner_configuration(self.trainer, train_dataloaders, val_dataloaders, dataloaders, method)
+ _check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)
+ _check_scale_batch_size_configuration(self._trainer)
+
+ # local import to avoid circular import
+ from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
batch_size_finder: Callback = BatchSizeFinder(
mode=mode,
@@ -176,21 +87,20 @@ def scale_batch_size(
max_trials=max_trials,
batch_arg_name=batch_arg_name,
)
- # do not continue with the loop in case trainer.tuner is used
+ # do not continue with the loop in case Tuner is used
batch_size_finder._early_exit = True
- self.trainer.callbacks = [batch_size_finder] + self.trainer.callbacks
+ self._trainer.callbacks = [batch_size_finder] + self._trainer.callbacks
if method == "fit":
- self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)
+ self._trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)
elif method == "validate":
- self.trainer.validate(model, dataloaders, datamodule=datamodule)
+ self._trainer.validate(model, dataloaders, datamodule=datamodule)
elif method == "test":
- self.trainer.test(model, dataloaders, datamodule=datamodule)
+ self._trainer.test(model, dataloaders, datamodule=datamodule)
elif method == "predict":
- self.trainer.predict(model, dataloaders, datamodule=datamodule)
+ self._trainer.predict(model, dataloaders, datamodule=datamodule)
- self.trainer.callbacks = [cb for cb in self.trainer.callbacks if cb is not batch_size_finder]
- self.trainer.auto_scale_batch_size = False
+ self._trainer.callbacks = [cb for cb in self._trainer.callbacks if cb is not batch_size_finder]
return batch_size_finder.optimal_batch_size
def lr_find(
@@ -206,31 +116,25 @@ def lr_find(
num_training: int = 100,
mode: str = "exponential",
early_stop_threshold: float = 4.0,
- update_attr: bool = False,
- ) -> Optional[_LRFinder]:
+ update_attr: bool = True,
+ attr_name: str = "",
+ ) -> Optional["pl.tuner.lr_finder._LRFinder"]:
"""Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in
picking a good starting learning rate.
Args:
model: Model to tune.
-
train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
In the case of multiple dataloaders, please see this :ref:`section `.
-
val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
-
dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
samples used for running tuner on validation/testing/prediction.
-
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
-
+            method: Method to run tuner on. Only ``"fit"`` is supported by the learning rate finder.
min_lr: minimum learning rate to investigate
-
max_lr: maximum learning rate to investigate
-
num_training: number of learning rates to test
-
mode: Search strategy to update learning rate after each batch:
- ``'exponential'``: Increases the learning rate exponentially.
@@ -239,18 +143,23 @@ def lr_find(
early_stop_threshold: Threshold for stopping the search. If the
loss at any point is larger than early_stop_threshold*best_loss
then the search is stopped. To disable, set to None.
-
update_attr: Whether to update the learning rate attribute or not.
+ attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
+ automatically detected. Otherwise, set the name here.
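+
+        Example::
+
+            # illustrative sketch (assumes ``model`` exposes a ``learning_rate`` or ``lr`` attribute)
+            tuner = Tuner(trainer)
+            lr_finder = tuner.lr_find(model, min_lr=1e-6, max_lr=1e-1)
+            print(lr_finder.suggestion())
+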
Raises:
MisconfigurationException:
- If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``,
+ If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden,
or if you are using more than one optimizer.
"""
if method != "fit":
raise MisconfigurationException("method='fit' is an invalid configuration to run lr finder.")
- _check_tuner_configuration(self.trainer, train_dataloaders, val_dataloaders, dataloaders, method)
+ _check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)
+ _check_lr_find_configuration(self._trainer)
+
+ # local import to avoid circular import
+ from pytorch_lightning.callbacks.lr_finder import LearningRateFinder
lr_finder_callback: Callback = LearningRateFinder(
min_lr=min_lr,
@@ -259,21 +168,20 @@ def lr_find(
mode=mode,
early_stop_threshold=early_stop_threshold,
update_attr=update_attr,
+ attr_name=attr_name,
)
lr_finder_callback._early_exit = True
- self.trainer.callbacks = [lr_finder_callback] + self.trainer.callbacks
+ self._trainer.callbacks = [lr_finder_callback] + self._trainer.callbacks
- self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)
+ self._trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)
- self.trainer.callbacks = [cb for cb in self.trainer.callbacks if cb is not lr_finder_callback]
+ self._trainer.callbacks = [cb for cb in self._trainer.callbacks if cb is not lr_finder_callback]
- self.trainer.auto_lr_find = False
return lr_finder_callback.optimal_lr
def _check_tuner_configuration(
- trainer: "pl.Trainer",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
dataloaders: Optional[EVAL_DATALOADERS] = None,
@@ -296,19 +204,32 @@ def _check_tuner_configuration(
" arguments should be None, please consider setting `dataloaders` instead."
)
- configured_callbacks = []
- for cb in trainer.callbacks:
- if isinstance(cb, BatchSizeFinder) and trainer.auto_scale_batch_size:
- configured_callbacks.append("BatchSizeFinder")
- elif isinstance(cb, LearningRateFinder) and trainer.auto_lr_find:
- configured_callbacks.append("LearningRateFinder")
- if len(configured_callbacks) == 1:
- raise MisconfigurationException(
- f"Trainer is already configured with a `{configured_callbacks[0]}` callback."
+
+def _check_lr_find_configuration(trainer: "pl.Trainer") -> None:
+ # local import to avoid circular import
+ from pytorch_lightning.callbacks.lr_finder import LearningRateFinder
+
+ configured_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, LearningRateFinder)]
+ if configured_callbacks:
+ raise ValueError(
+            "Trainer is already configured with a `LearningRateFinder` callback. "
"Please remove it if you want to use the Tuner."
)
- elif len(configured_callbacks) == 2:
- raise MisconfigurationException(
- "Trainer is already configured with `LearningRateFinder` and `BatchSizeFinder` callbacks."
- " Please remove them if you want to use the Tuner."
+
+
+def _check_scale_batch_size_configuration(trainer: "pl.Trainer") -> None:
+ if trainer._accelerator_connector.is_distributed:
+ raise ValueError(
+ "Tuning the batch size is currently not supported with"
+ f" `Trainer(strategy={trainer.strategy.strategy_name!r})`."
+ )
+
+ # local import to avoid circular import
+ from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
+
+ configured_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, BatchSizeFinder)]
+ if configured_callbacks:
+ raise ValueError(
+            "Trainer is already configured with a `BatchSizeFinder` callback. "
+ "Please remove it if you want to use the Tuner."
)
diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py
index 3a40be2074691..edfa6e44a2e58 100644
--- a/tests/tests_pytorch/core/test_lightning_optimizer.py
+++ b/tests/tests_pytorch/core/test_lightning_optimizer.py
@@ -21,6 +21,7 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.loops.optimization.optimizer_loop import Closure
+from pytorch_lightning.tuner.tuning import Tuner
@pytest.mark.parametrize("auto", (True, False))
@@ -54,9 +55,10 @@ def compare_optimizers():
model = BoringModel()
model.lr = 0.2
- trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_lr_find=True)
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+ tuner = Tuner(trainer)
- trainer.tune(model)
+ tuner.lr_find(model)
compare_optimizers()
trainer.fit(model)
diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py
index 266cbeb05c5bc..78eb59be2cf19 100644
--- a/tests/tests_pytorch/loggers/test_all.py
+++ b/tests/tests_pytorch/loggers/test_all.py
@@ -32,6 +32,7 @@
)
from pytorch_lightning.loggers.logger import DummyExperiment
from pytorch_lightning.loggers.tensorboard import _TENSORBOARD_AVAILABLE
+from pytorch_lightning.tuner.tuning import Tuner
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.loggers.test_comet import _patch_comet_atexit
from tests_pytorch.loggers.test_mlflow import mock_mlflow_run_creation
@@ -201,14 +202,8 @@ def _test_loggers_pickle(tmpdir, monkeypatch, logger_class):
assert trainer2.logger.save_dir == logger.save_dir
-@pytest.mark.parametrize(
- "extra_params",
- [
- pytest.param(dict(max_epochs=1, auto_scale_batch_size=True), id="Batch-size-Finder"),
- pytest.param(dict(max_epochs=3, auto_lr_find=True), id="LR-Finder"),
- ],
-)
-def test_logger_reset_correctly(tmpdir, extra_params):
+@pytest.mark.parametrize("tuner_method", ["lr_find", "scale_batch_size"])
+def test_logger_reset_correctly(tmpdir, tuner_method):
"""Test that the tuners do not alter the logger reference."""
class CustomModel(BoringModel):
@@ -217,9 +212,11 @@ def __init__(self, lr=0.1, batch_size=1):
self.save_hyperparameters()
model = CustomModel()
- trainer = Trainer(default_root_dir=tmpdir, **extra_params)
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+ tuner = Tuner(trainer)
+
logger1 = trainer.logger
- trainer.tune(model)
+ getattr(tuner, tuner_method)(model)
logger2 = trainer.logger
logger3 = model.logger
diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py
index 1034bd14bd4cc..29c4e62f6220d 100644
--- a/tests/tests_pytorch/models/test_hooks.py
+++ b/tests/tests_pytorch/models/test_hooks.py
@@ -802,9 +802,6 @@ def test_trainer_model_hook_system_predict(tmpdir):
assert called == expected
-# TODO: add test for tune
-
-
def test_hooks_with_different_argument_names(tmpdir):
"""Test that argument names can be anything in the hooks."""
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index 360720e6c05f9..3d7b0c4f65113 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -801,15 +801,6 @@ def predict(self, **_):
def after_predict(self):
self.called.append("after_predict")
- def before_tune(self):
- self.called.append("before_tune")
-
- def tune(self, **_):
- self.called.append("tune")
-
- def after_tune(self):
- self.called.append("after_tune")
-
with mock.patch("sys.argv", ["any.py", fn]):
cli = TestCLI(BoringModel)
assert cli.called == [f"before_{fn}", fn, f"after_{fn}"]
@@ -850,7 +841,7 @@ def subcommands():
TestCLI(BoringModel, trainer_class=TestTrainer)
out = out.getvalue()
assert "Sample extra function." in out
- assert "{fit,validate,test,predict,tune,foo}" in out
+ assert "{fit,validate,test,predict,foo}" in out
out = StringIO()
with mock.patch("sys.argv", ["any.py", "foo", "-h"]), redirect_stdout(out), pytest.raises(SystemExit):
@@ -1337,7 +1328,7 @@ def __init__(self, activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReL
assert cli.model.activation is not model.activation
-def test_ddpstrategy_instantiation_and_find_unused_parameters():
+def test_ddpstrategy_instantiation_and_find_unused_parameters(mps_count_0):
strategy_default = lazy_instance(DDPStrategy, find_unused_parameters=True)
with mock.patch("sys.argv", ["any.py", "--trainer.strategy.process_group_backend=group"]):
cli = LightningCLI(
diff --git a/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py b/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py
index 0e54f2993e708..fa6ca71007e81 100644
--- a/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py
+++ b/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py
@@ -9,10 +9,10 @@
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers.logger import DummyLogger
+from pytorch_lightning.tuner.tuning import Tuner
-@pytest.mark.parametrize("tuner_alg", ["batch size scaler", "learning rate finder"])
-def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
+def test_skip_on_fast_dev_run_tuner(tmpdir):
"""Test that tuner algorithms are skipped if fast dev run is enabled."""
model = BoringModel()
@@ -21,13 +21,15 @@ def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
- auto_scale_batch_size=(tuner_alg == "batch size scaler"),
- auto_lr_find=(tuner_alg == "learning rate finder"),
fast_dev_run=True,
)
- expected_message = f"Skipping {tuner_alg} since `fast_dev_run` is enabled."
- with pytest.warns(UserWarning, match=expected_message):
- trainer.tune(model)
+ tuner = Tuner(trainer)
+
+ with pytest.warns(UserWarning, match="Skipping learning rate finder since `fast_dev_run` is enabled."):
+ tuner.lr_find(model)
+
+ with pytest.warns(UserWarning, match="Skipping batch size scaler since `fast_dev_run` is enabled."):
+ tuner.scale_batch_size(model)
@pytest.mark.parametrize("fast_dev_run", [1, 4])
diff --git a/tests/tests_pytorch/trainer/test_states.py b/tests/tests_pytorch/trainer/test_states.py
index d45fb24f933c4..c884787d2811e 100644
--- a/tests/tests_pytorch/trainer/test_states.py
+++ b/tests/tests_pytorch/trainer/test_states.py
@@ -29,8 +29,6 @@ def test_initialize_state():
[pytest.param(dict(fast_dev_run=True), id="Fast-Run"), pytest.param(dict(max_steps=1), id="Single-Step")],
)
def test_trainer_fn_while_running(tmpdir, extra_params):
- trainer = Trainer(default_root_dir=tmpdir, **extra_params, auto_lr_find=True)
-
class TestModel(BoringModel):
def __init__(self, expected_fn, expected_stage):
super().__init__()
@@ -58,9 +56,7 @@ def on_test_batch_start(self, *_):
assert self.trainer.state.fn == self.expected_fn
assert self.trainer.testing
- model = TestModel(TrainerFn.FITTING, RunningStage.TRAINING)
- trainer.tune(model)
- assert trainer.state.finished
+ trainer = Trainer(default_root_dir=tmpdir, **extra_params)
model = TestModel(TrainerFn.FITTING, RunningStage.TRAINING)
trainer.fit(model)
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index 09cad4157a9ee..ae4e6268ce29d 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -72,10 +72,6 @@ def test_trainer_error_when_input_not_lightning_module():
run_method = getattr(trainer, method)
run_method(nn.Linear(2, 2))
- trainer = Trainer(auto_lr_find=True, auto_scale_batch_size=True)
- with pytest.raises(TypeError, match="must be a `LightningModule`.*got `Linear"):
- trainer.tune(nn.Linear(2, 2))
-
@pytest.mark.parametrize("url_ckpt", [True, False])
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
diff --git a/tests/tests_pytorch/trainer/test_trainer_cli.py b/tests/tests_pytorch/trainer/test_trainer_cli.py
index a28c88ad10a5e..455b2783670c7 100644
--- a/tests/tests_pytorch/trainer/test_trainer_cli.py
+++ b/tests/tests_pytorch/trainer/test_trainer_cli.py
@@ -112,14 +112,6 @@ def _raise():
@pytest.mark.parametrize(
["cli_args", "expected"],
[
- ("--auto_lr_find --auto_scale_batch_size power", {"auto_lr_find": True, "auto_scale_batch_size": "power"}),
- (
- "--auto_lr_find any_string --auto_scale_batch_size",
- {"auto_lr_find": "any_string", "auto_scale_batch_size": True},
- ),
- ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", {"auto_lr_find": True, "auto_scale_batch_size": False}),
- ("--auto_lr_find t --auto_scale_batch_size ON", {"auto_lr_find": True, "auto_scale_batch_size": True}),
- ("--auto_lr_find 0 --auto_scale_batch_size n", {"auto_lr_find": False, "auto_scale_batch_size": False}),
(
"",
{
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index 1277f6939acfe..edbd1998b3e16 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -24,6 +24,7 @@
from pytorch_lightning.callbacks.lr_finder import LearningRateFinder
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.tuner.lr_finder import _LRFinder
+from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
@@ -48,22 +49,24 @@ def configure_optimizers(self):
# logger file to get meta
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+ tuner = Tuner(trainer)
with pytest.raises(MisconfigurationException, match="only works with single optimizer"):
- trainer.tuner.lr_find(model)
+ tuner.lr_find(model)
def test_model_reset_correctly(tmpdir):
- """Check that model weights are correctly reset after lr_find()"""
+ """Check that model weights are correctly reset after _lr_find()"""
model = BoringModel()
+ model.lr = 0.1
# logger file to get meta
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
-
+ tuner = Tuner(trainer)
before_state_dict = deepcopy(model.state_dict())
- trainer.tuner.lr_find(model, num_training=5)
+ tuner.lr_find(model, num_training=5)
after_state_dict = model.state_dict()
@@ -79,13 +82,14 @@ def test_trainer_reset_correctly(tmpdir):
"""Check that all trainer parameters are reset correctly after lr_find()"""
model = BoringModel()
+ model.lr = 0.1
# logger file to get meta
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+ tuner = Tuner(trainer)
changed_attributes = [
"accumulate_grad_batches",
- "auto_lr_find",
"callbacks",
"checkpoint_callback",
"current_epoch",
@@ -99,7 +103,7 @@ def test_trainer_reset_correctly(tmpdir):
expected = {ca: getattr_recursive(trainer, ca) for ca in changed_attributes}
with no_warning_call(UserWarning, match="Please add the following callbacks"):
- trainer.tuner.lr_find(model, num_training=5)
+ tuner.lr_find(model, num_training=5)
actual = {ca: getattr_recursive(trainer, ca) for ca in changed_attributes}
assert actual == expected
@@ -107,8 +111,8 @@ def test_trainer_reset_correctly(tmpdir):
@pytest.mark.parametrize("use_hparams", [False, True])
-def test_trainer_arg_bool(tmpdir, use_hparams):
- """Test that setting trainer arg to bool works."""
+def test_tuner_lr_find(tmpdir, use_hparams):
+ """Test that lr_find updates the learning rate attribute."""
seed_everything(1)
class CustomBoringModel(BoringModel):
@@ -123,9 +127,10 @@ def configure_optimizers(self):
before_lr = 1e-2
model = CustomBoringModel(lr=before_lr)
- trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, auto_lr_find=True)
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)
+ tuner = Tuner(trainer)
+ tuner.lr_find(model, update_attr=True)
- trainer.tune(model)
if use_hparams:
after_lr = model.hparams.lr
else:
@@ -154,9 +159,9 @@ def configure_optimizers(self):
before_lr = 1e-2
model = CustomBoringModel(my_fancy_lr=before_lr)
- trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, auto_lr_find="my_fancy_lr")
-
- trainer.tune(model)
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)
+ tuner = Tuner(trainer)
+ tuner.lr_find(model, update_attr=True, attr_name="my_fancy_lr")
if use_hparams:
after_lr = model.hparams.my_fancy_lr
else:
@@ -188,11 +193,12 @@ def configure_optimizers(self):
model = CustomBoringModel(1e-2)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)
- lrfinder = trainer.tuner.lr_find(model, mode="linear")
- after_lr = lrfinder.suggestion()
+ tuner = Tuner(trainer)
+ lr_finder = tuner.lr_find(model, mode="linear")
+ after_lr = lr_finder.suggestion()
assert after_lr is not None
model.hparams.lr = after_lr
- trainer.tune(model)
+ tuner.lr_find(model, update_attr=True)
assert after_lr is not None
assert before_lr != after_lr, "Learning rate was not altered after running learning rate finder"
@@ -211,8 +217,9 @@ def test_datamodule_parameter(tmpdir):
# logger file to get meta
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)
- lrfinder = trainer.tuner.lr_find(model, datamodule=dm)
- after_lr = lrfinder.suggestion()
+ tuner = Tuner(trainer)
+ lr_finder = tuner.lr_find(model, datamodule=dm)
+ after_lr = lr_finder.suggestion()
model.lr = after_lr
assert after_lr is not None
@@ -231,11 +238,12 @@ def __init__(self):
model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, accumulate_grad_batches=2)
- lrfinder = trainer.tuner.lr_find(model, early_stop_threshold=None)
+ tuner = Tuner(trainer)
+ lr_finder = tuner.lr_find(model, early_stop_threshold=None)
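+ # with the default num_training=100 lr candidates and accumulate_grad_batches=2, 200 batches run, so the last (0-indexed) batch index is 199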
- assert lrfinder.suggestion() != 1e-3
- assert len(lrfinder.results["lr"]) == 100
- assert lrfinder._total_batch_idx == 199
+ assert lr_finder.suggestion() != 1e-3
+ assert len(lr_finder.results["lr"]) == 100
+ assert lr_finder._total_batch_idx == 199
def test_suggestion_parameters_work(tmpdir):
@@ -254,10 +262,10 @@ def configure_optimizers(self):
# logger file to get meta
model = CustomBoringModel(lr=1e-2)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=3)
-
- lrfinder = trainer.tuner.lr_find(model)
- lr1 = lrfinder.suggestion(skip_begin=10) # default
- lr2 = lrfinder.suggestion(skip_begin=70) # way too high, should have an impact
+ tuner = Tuner(trainer)
+ lr_finder = tuner.lr_find(model)
+ lr1 = lr_finder.suggestion(skip_begin=10) # default
+ lr2 = lr_finder.suggestion(skip_begin=70) # way too high, should have an impact
assert lr1 is not None
assert lr2 is not None
@@ -279,11 +287,12 @@ def configure_optimizers(self):
model = CustomBoringModel(lr=1e-2)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=3)
+ tuner = Tuner(trainer)
+ lr_finder = tuner.lr_find(model)
- lrfinder = trainer.tuner.lr_find(model)
- before_lr = lrfinder.suggestion()
- lrfinder.results["loss"][-1] = float("nan")
- after_lr = lrfinder.suggestion()
+ before_lr = lr_finder.suggestion()
+ lr_finder.results["loss"][-1] = float("nan")
+ after_lr = lr_finder.suggestion()
assert before_lr is not None
assert after_lr is not None
@@ -292,32 +301,10 @@ def configure_optimizers(self):
def test_lr_finder_fails_fast_on_bad_config(tmpdir):
"""Test that tune fails if the model does not have a lr BEFORE running lr find."""
- trainer = Trainer(default_root_dir=tmpdir, max_steps=2, auto_lr_find=True)
- with pytest.raises(MisconfigurationException, match="should have one of these fields"):
- trainer.tune(BoringModel())
-
-
-def test_lr_find_with_bs_scale(tmpdir):
- """Test that lr_find runs with batch_size_scaling."""
- seed_everything(1)
-
- class BoringModelTune(BoringModel):
- def __init__(self, learning_rate=0.1, batch_size=2):
- super().__init__()
- self.save_hyperparameters()
-
- model = BoringModelTune()
- before_lr = model.hparams.learning_rate
-
- # logger file to get meta
- trainer = Trainer(default_root_dir=tmpdir, max_epochs=3, auto_lr_find=True, auto_scale_batch_size=True)
- result = trainer.tune(model)
- bs = result["scale_batch_size"]
- after_lr = result["lr_find"].suggestion()
-
- assert after_lr is not None
- assert after_lr != before_lr
- assert isinstance(bs, int)
+ trainer = Trainer(default_root_dir=tmpdir, max_steps=2)
+ tuner = Tuner(trainer)
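+ # BoringModel defines neither `lr` nor `learning_rate`, so lr_find with update_attr=True should fail fast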
+ with pytest.raises(AttributeError, match="should have one of these fields"):
+ tuner.lr_find(BoringModel(), update_attr=True)
def test_lr_candidates_between_min_and_max(tmpdir):
@@ -334,7 +321,8 @@ def __init__(self, learning_rate=0.1):
lr_min = 1e-8
lr_max = 1.0
- lr_finder = trainer.tuner.lr_find(model, max_lr=lr_min, min_lr=lr_max, num_training=3)
+ tuner = Tuner(trainer)
+ lr_finder = tuner.lr_find(model, max_lr=lr_min, min_lr=lr_max, num_training=3)
lr_candidates = lr_finder.results["lr"]
assert all(lr_min <= lr <= lr_max for lr in lr_candidates)
@@ -353,14 +341,16 @@ def training_step_end(self, outputs):
model = TestModel()
trainer = Trainer(default_root_dir=tmpdir)
+ tuner = Tuner(trainer)
num_training = 3
- trainer.tuner.lr_find(model=model, num_training=num_training)
+ tuner.lr_find(model=model, num_training=num_training)
def test_multiple_lr_find_calls_gives_same_results(tmpdir):
"""Tests that lr_finder gives same results if called multiple times."""
seed_everything(1)
model = BoringModel()
+ model.lr = 0.1
trainer = Trainer(
default_root_dir=tmpdir,
@@ -371,7 +361,8 @@ def test_multiple_lr_find_calls_gives_same_results(tmpdir):
enable_model_summary=False,
enable_checkpointing=False,
)
- all_res = [trainer.tuner.lr_find(model).results for _ in range(3)]
+ tuner = Tuner(trainer)
+ all_res = [tuner.lr_find(model).results for _ in range(3)]
assert all(
all_res[0][k] == curr_lr_finder[k] and len(curr_lr_finder[k]) > 10
@@ -427,21 +418,12 @@ def __init__(self):
model = TestModel()
trainer = Trainer(default_root_dir=tmpdir)
- lr_finder = trainer.tuner.lr_find(model=model, update_attr=True, num_training=1) # force insufficient data points
+ tuner = Tuner(trainer)
+ lr_finder = tuner.lr_find(model=model, update_attr=True, num_training=1) # force insufficient data points
assert lr_finder.suggestion() is None
assert model.learning_rate == 0.123 # must remain unchanged because suggestion is not possible
-def test_if_lr_finder_callback_already_configured():
- """Test that an error is raised if `LearningRateFinder` is already configured inside `Tuner`"""
- cb = LearningRateFinder()
- trainer = Trainer(auto_lr_find=True, callbacks=cb)
- model = BoringModel()
-
- with pytest.raises(MisconfigurationException, match="Trainer is already configured with a `LearningRateFinder`"):
- trainer.tune(model)
-
-
def test_lr_finder_callback_restarting(tmpdir):
"""Test that `LearningRateFinder` does not set restarting=True when loading checkpoint."""
@@ -506,7 +488,8 @@ def test_lr_finder_with_ddp(tmpdir):
accelerator="cpu",
)
- trainer.tuner.lr_find(model, datamodule=dm, update_attr=True, num_training=20)
+ tuner = Tuner(trainer)
+ tuner.lr_find(model, datamodule=dm, update_attr=True, num_training=20)
lr = trainer.lightning_module.lr
lr = trainer.strategy.broadcast(lr)
assert trainer.lightning_module.lr == lr
diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py
index 5f6e1ec035b36..2dec52cd8cf8c 100644
--- a/tests/tests_pytorch/tuner/test_scale_batch_size.py
+++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -23,6 +23,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
+from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.runif import RunIf
@@ -64,9 +65,8 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_b
model = BatchSizeModel(model_bs)
datamodule = BatchSizeDataModule(dm_bs) if dm_bs != -1 else None
- new_batch_size = trainer.tuner.scale_batch_size(
- model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
- )
+ tuner = Tuner(trainer)
+ new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
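+ # starting from init_val=4, two successful trials double the batch size twice: 4 -> 8 -> 16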
assert new_batch_size == 16
if model_bs is not None:
@@ -87,6 +87,7 @@ def test_trainer_reset_correctly(tmpdir, trainer_fn):
# logger file to get meta
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+ tuner = Tuner(trainer)
changed_attributes = [
"loggers",
@@ -103,7 +104,7 @@ def test_trainer_reset_correctly(tmpdir, trainer_fn):
expected_loop_state_dict = trainer.fit_loop.state_dict()
with no_warning_call(UserWarning, match="Please add the following callbacks"):
- trainer.tuner.scale_batch_size(model, max_trials=64, method=trainer_fn)
+ tuner.scale_batch_size(model, max_trials=64, method=trainer_fn)
actual = {ca: getattr(trainer, ca) for ca in changed_attributes}
actual_loop_state_dict = trainer.fit_loop.state_dict()
@@ -125,10 +126,9 @@ def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg):
"""Test possible values for 'batch size auto scaling' Trainer argument."""
before_batch_size = 2
model = BatchSizeModel(batch_size=before_batch_size)
- trainer = Trainer(
- default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=scale_arg, accelerator="gpu", devices=1
- )
- trainer.tune(model)
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, accelerator="gpu", devices=1)
+ tuner = Tuner(trainer)
+ tuner.scale_batch_size(model)
after_batch_size = model.batch_size
assert before_batch_size != after_batch_size, "Batch size was not altered after running auto scaling of batch size"
@@ -155,8 +155,9 @@ def val_dataloader(self):
model_class = HparamsBatchSizeModel if use_hparams else BatchSizeModel
model = model_class(**hparams)
- trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
- trainer.tune(model, scale_batch_size_kwargs={"steps_per_trial": 2, "max_trials": 4})
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+ tuner = Tuner(trainer)
+ tuner.scale_batch_size(model, steps_per_trial=2, max_trials=4)
after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
assert before_batch_size != after_batch_size
assert after_batch_size <= len(trainer.train_dataloader.dataset)
@@ -183,8 +184,9 @@ def val_dataloader(self):
datamodule = datamodule_class(batch_size=before_batch_size)
model = BatchSizeModel(**hparams)
- trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
- trainer.tune(model, datamodule=datamodule, scale_batch_size_kwargs={"steps_per_trial": 2, "max_trials": 4})
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+ tuner = Tuner(trainer)
+ tuner.scale_batch_size(model, datamodule=datamodule, steps_per_trial=2, max_trials=4)
after_batch_size = datamodule.hparams.batch_size if use_hparams else datamodule.batch_size
assert trainer.datamodule == datamodule
assert before_batch_size < after_batch_size
@@ -202,10 +204,12 @@ def __init__(self, batch_size=1):
self.save_hyperparameters()
model = TestModel()
- trainer = Trainer(default_root_dir=tmpdir, max_steps=1, max_epochs=1000, auto_scale_batch_size=True)
+ trainer = Trainer(default_root_dir=tmpdir, max_steps=1, max_epochs=1000)
+ tuner = Tuner(trainer)
+
expected_message = "Field `model.batch_size` and `model.hparams.batch_size` are mutually exclusive!"
with pytest.warns(UserWarning, match=expected_message):
- trainer.tune(model)
+ tuner.scale_batch_size(model)
@pytest.mark.parametrize("scale_method", ["power", "binsearch"])
@@ -216,8 +220,9 @@ def test_call_to_trainer_method(tmpdir, scale_method):
# logger file to get meta
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+ tuner = Tuner(trainer)
- after_batch_size = trainer.tuner.scale_batch_size(model, mode=scale_method, max_trials=5)
+ after_batch_size = tuner.scale_batch_size(model, mode=scale_method, max_trials=5)
model.batch_size = after_batch_size
trainer.fit(model)
@@ -225,7 +230,7 @@ def test_call_to_trainer_method(tmpdir, scale_method):
def test_error_on_dataloader_passed_to_fit(tmpdir):
- """Verify that when the auto scale batch size feature raises an error if a train dataloader is passed to
+ """Verify that when the auto-scale batch size feature raises an error if a train dataloader is passed to
fit."""
# only train passed to fit
@@ -235,25 +240,23 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
max_epochs=1,
limit_val_batches=0.1,
limit_train_batches=0.2,
- auto_scale_batch_size="power",
)
- fit_options = dict(train_dataloaders=model.train_dataloader())
+ tuner = Tuner(trainer)
with pytest.raises(
MisconfigurationException,
match="Batch size finder cannot be used with dataloaders passed directly",
):
- trainer.tune(model, **fit_options)
+ tuner.scale_batch_size(model, train_dataloaders=model.train_dataloader(), mode="power")
@RunIf(min_cuda_gpus=1)
def test_auto_scale_batch_size_with_amp(tmpdir):
before_batch_size = 2
model = BatchSizeModel(batch_size=before_batch_size)
- trainer = Trainer(
- default_root_dir=tmpdir, max_steps=1, auto_scale_batch_size=True, accelerator="gpu", devices=1, precision=16
- )
- trainer.tune(model)
+ trainer = Trainer(default_root_dir=tmpdir, max_steps=1, accelerator="gpu", devices=1, precision=16)
+ tuner = Tuner(trainer)
+ tuner.scale_batch_size(model)
after_batch_size = model.batch_size
assert trainer.scaler is not None
assert after_batch_size != before_batch_size
@@ -261,11 +264,10 @@ def test_auto_scale_batch_size_with_amp(tmpdir):
def test_scale_batch_size_no_trials(tmpdir):
"""Check the result is correct even when no trials are run."""
- trainer = Trainer(
- default_root_dir=tmpdir, max_epochs=1, limit_val_batches=1, limit_train_batches=1, auto_scale_batch_size="power"
- )
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=1, limit_train_batches=1)
+ tuner = Tuner(trainer)
model = BatchSizeModel(batch_size=2)
- result = trainer.tuner.scale_batch_size(model, max_trials=0)
+ result = tuner.scale_batch_size(model, max_trials=0, mode="power")
assert result == 2
@@ -283,13 +285,11 @@ def __init__(self):
max_epochs=1,
limit_val_batches=1,
limit_train_batches=1,
- auto_scale_batch_size="ThisModeDoesNotExist",
)
+ tuner = Tuner(trainer)
with pytest.raises(ValueError, match="should be either of"):
- trainer.tune(model)
- with pytest.raises(ValueError, match="should be either of"):
- trainer.tuner.scale_batch_size(model, mode="ThisModeDoesNotExist")
+ tuner.scale_batch_size(model, mode="ThisModeDoesNotExist")
@pytest.mark.parametrize("scale_method", ["power", "binsearch"])
@@ -305,9 +305,11 @@ def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):
"mode": scale_method,
}
- trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+ tuner = Tuner(trainer)
+
with patch.object(model, "on_train_epoch_end") as advance_mocked:
- new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+ new_batch_size = tuner.scale_batch_size(model, **scale_batch_size_kwargs)
assert advance_mocked.call_count == max_trials
assert trainer.train_dataloader.loaders.batch_size == new_batch_size
@@ -322,10 +324,9 @@ def test_tuner_with_evaluation_methods(tmpdir, trainer_fn):
expected_scaled_batch_size = before_batch_size ** (max_trials + 1)
model = BatchSizeModel(batch_size=before_batch_size)
- trainer = Trainer(default_root_dir=tmpdir, max_epochs=100, auto_scale_batch_size=True)
- trainer.tune(
- model, scale_batch_size_kwargs={"max_trials": max_trials, "batch_arg_name": "batch_size"}, method=trainer_fn
- )
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=100)
+ tuner = Tuner(trainer)
+ tuner.scale_batch_size(model, max_trials=max_trials, batch_arg_name="batch_size", method=trainer_fn)
after_batch_size = model.batch_size
loop = getattr(trainer, f"{trainer_fn}_loop")
@@ -380,52 +381,45 @@ def test_batch_size_finder_callback(tmpdir, trainer_fn):
def test_invalid_method_in_tuner():
"""Test that an invalid value for `method` raises an error in `Tuner`"""
- trainer = Trainer(auto_scale_batch_size=True)
+ trainer = Trainer()
+ tuner = Tuner(trainer)
model = BoringModel()
with pytest.raises(ValueError, match="method .* is invalid."):
- trainer.tune(model, method="prediction")
-
-
-def test_if_batch_size_finder_callback_already_configured():
- """Test that an error is raised if BatchSizeFinder is already configured inside `Tuner`"""
- cb = BatchSizeFinder()
- trainer = Trainer(auto_scale_batch_size=True, callbacks=cb)
- model = BoringModel()
-
- with pytest.raises(MisconfigurationException, match="Trainer is already configured with a `BatchSizeFinder`"):
- trainer.tune(model)
+ tuner.scale_batch_size(model, method="prediction")
def test_error_if_train_or_val_dataloaders_passed_with_eval_method():
"""Test that an error is raised if `train_dataloaders` or `val_dataloaders` is passed with eval method inside
`Tuner`"""
- trainer = Trainer(auto_scale_batch_size=True)
+ trainer = Trainer()
+ tuner = Tuner(trainer)
model = BoringModel()
dl = model.train_dataloader()
with pytest.raises(MisconfigurationException, match="please consider setting `dataloaders` instead"):
- trainer.tune(model, train_dataloaders=dl, method="validate")
+ tuner.scale_batch_size(model, train_dataloaders=dl, method="validate")
with pytest.raises(MisconfigurationException, match="please consider setting `dataloaders` instead"):
- trainer.tune(model, val_dataloaders=dl, method="validate")
+ tuner.scale_batch_size(model, val_dataloaders=dl, method="validate")
def test_error_if_dataloaders_passed_with_fit_method():
"""Test that an error is raised if `dataloaders` is passed with fit method inside `Tuner`"""
- trainer = Trainer(auto_scale_batch_size=True)
+ trainer = Trainer()
+ tuner = Tuner(trainer)
model = BoringModel()
dl = model.val_dataloader()
with pytest.raises(
MisconfigurationException, match="please consider setting `train_dataloaders` and `val_dataloaders` instead"
):
- trainer.tune(model, dataloaders=dl, method="fit")
+ tuner.scale_batch_size(model, dataloaders=dl, method="fit")
def test_batch_size_finder_with_distributed_strategies():
"""Test that an error is raised when batch size finder is used with multi-device strategy."""
- trainer = Trainer(auto_scale_batch_size=True, devices=2, strategy="ddp", accelerator="cpu")
+ trainer = Trainer(devices=2, strategy="ddp", accelerator="cpu")
model = BoringModel()
bs_finder = BatchSizeFinder()
@@ -442,13 +436,14 @@ class CustomModel(BoringModel):
def val_dataloader(self):
return [super().val_dataloader(), super().val_dataloader()]
- trainer = Trainer(auto_scale_batch_size=True)
+ trainer = Trainer()
+ tuner = Tuner(trainer)
model = CustomModel()
with pytest.raises(
MisconfigurationException, match="Batch size finder cannot be used with multiple .* dataloaders"
):
- trainer.tune(model, method="validate")
+ tuner.scale_batch_size(model, method="validate")
@pytest.mark.parametrize("scale_method, expected_batch_size", [("power", 62), ("binsearch", 100)])
@@ -467,8 +462,9 @@ def train_dataloader(self):
model.training_epoch_end = None
scale_batch_size_kwargs = {"max_trials": 10, "steps_per_trial": 1, "init_val": 500, "mode": scale_method}
- trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, auto_scale_batch_size=True)
- new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)
+ tuner = Tuner(trainer)
+ new_batch_size = tuner.scale_batch_size(model, **scale_batch_size_kwargs)
assert new_batch_size == model.batch_size
assert new_batch_size == expected_batch_size
assert trainer.train_dataloader.loaders.batch_size == expected_batch_size
diff --git a/tests/tests_pytorch/tuner/test_tuning.py b/tests/tests_pytorch/tuner/test_tuning.py
index 1fd3cd33afeab..c49f0874cfe7b 100644
--- a/tests/tests_pytorch/tuner/test_tuning.py
+++ b/tests/tests_pytorch/tuner/test_tuning.py
@@ -16,47 +16,34 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BatchSizeFinder, LearningRateFinder
from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.tuner.tuning import Tuner
def test_tuner_with_distributed_strategies():
"""Test that an error is raised when tuner is used with multi-device strategy."""
- trainer = Trainer(auto_scale_batch_size=True, devices=2, strategy="ddp", accelerator="cpu")
+ trainer = Trainer(devices=2, strategy="ddp", accelerator="cpu")
+ tuner = Tuner(trainer)
model = BoringModel()
- with pytest.raises(MisconfigurationException, match=r"not supported with `Trainer\(strategy='ddp'\)`"):
- trainer.tune(model)
+ with pytest.raises(ValueError, match=r"not supported with `Trainer\(strategy='ddp'\)`"):
+ tuner.scale_batch_size(model)
def test_tuner_with_already_configured_batch_size_finder():
"""Test that an error is raised when tuner is already configured with BatchSizeFinder."""
- trainer = Trainer(auto_scale_batch_size=True, callbacks=[BatchSizeFinder()])
+ trainer = Trainer(callbacks=[BatchSizeFinder()])
+ tuner = Tuner(trainer)
model = BoringModel()
- with pytest.raises(MisconfigurationException, match=r"Trainer is already configured with a `BatchSizeFinder`"):
- trainer.tune(model)
+ with pytest.raises(ValueError, match=r"Trainer is already configured with a `BatchSizeFinder`"):
+ tuner.scale_batch_size(model)
def test_tuner_with_already_configured_learning_rate_finder():
"""Test that an error is raised when tuner is already configured with LearningRateFinder."""
- trainer = trainer = Trainer(auto_lr_find=True, callbacks=[LearningRateFinder()])
+ trainer = Trainer(callbacks=[LearningRateFinder()])
+ tuner = Tuner(trainer)
model = BoringModel()
- with pytest.raises(MisconfigurationException, match=r"Trainer is already configured with a `LearningRateFinder`"):
- trainer.tune(model)
-
-
-def test_tuner_with_already_configured_learning_rate_finder_and_batch_size_finder():
- """Test that an error is raised when tuner are already configured with LearningRateFinder and
- BatchSizeFinder."""
- trainer = trainer = Trainer(
- auto_lr_find=True, auto_scale_batch_size=True, callbacks=[LearningRateFinder(), BatchSizeFinder()]
- )
- model = BoringModel()
-
- with pytest.raises(
- MisconfigurationException,
- match=r"Trainer is already configured with `LearningRateFinder` and "
- r"`BatchSizeFinder` callbacks. Please remove them if you want to use the Tuner.",
- ):
- trainer.tune(model)
+ with pytest.raises(ValueError, match=r"Trainer is already configured with a `LearningRateFinder`"):
+ tuner.lr_find(model)