Decouple Tuner from Trainer #16462

Merged (40 commits) on Jan 27, 2023

Commits
427f428
removal
awaelchli Jan 21, 2023
ce1da14
delete
awaelchli Jan 21, 2023
6568581
remove
awaelchli Jan 21, 2023
e864e92
api docs
awaelchli Jan 21, 2023
ce3ac54
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 21, 2023
ced78cc
attr_name
awaelchli Jan 22, 2023
2799a23
Merge branch 'master' into removal/tuner
awaelchli Jan 26, 2023
86cd0e4
tests
awaelchli Jan 26, 2023
ffb2bfa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2023
42793d6
revert
awaelchli Jan 26, 2023
e9cb7a9
Merge remote-tracking branch 'origin/removal/tuner' into removal/tuner
awaelchli Jan 26, 2023
ebe6795
checks
awaelchli Jan 26, 2023
166f120
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2023
3f6a890
test tuning
awaelchli Jan 26, 2023
c4360b8
Merge remote-tracking branch 'origin/removal/tuner' into removal/tuner
awaelchli Jan 26, 2023
d6761e3
fixes
awaelchli Jan 26, 2023
37b8321
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2023
40ebcc6
fixes
awaelchli Jan 27, 2023
bdb6ae3
Merge remote-tracking branch 'origin/removal/tuner' into removal/tuner
awaelchli Jan 27, 2023
54d13fb
refactor
awaelchli Jan 27, 2023
32af540
docstring
awaelchli Jan 27, 2023
549c988
tests
awaelchli Jan 27, 2023
718d6ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2023
5a942e4
docs
awaelchli Jan 27, 2023
6431f4e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2023
ffdf205
remove setter
awaelchli Jan 27, 2023
7dc7e6c
changelog and defaults
awaelchli Jan 27, 2023
5151c2f
Merge remote-tracking branch 'origin/removal/tuner' into removal/tuner
awaelchli Jan 27, 2023
a22d0a9
types
awaelchli Jan 27, 2023
ccb56ae
chlog
awaelchli Jan 27, 2023
0d87706
resolve circular import
awaelchli Jan 27, 2023
c023d4a
Merge branch 'master' into removal/tuner
carmocca Jan 27, 2023
3f62661
Update src/pytorch_lightning/tuner/lr_finder.py
awaelchli Jan 27, 2023
ec6cbdf
remove resolved todo for circular import
awaelchli Jan 27, 2023
1665bd6
Merge remote-tracking branch 'origin/removal/tuner' into removal/tuner
awaelchli Jan 27, 2023
3298ea0
Merge branch 'master' into removal/tuner
carmocca Jan 27, 2023
2a449ac
pre-commit
carmocca Jan 27, 2023
2218060
Remove stale TODO
carmocca Jan 27, 2023
8687582
Update src/pytorch_lightning/CHANGELOG.md
awaelchli Jan 27, 2023
68fbbfd
Merge branch 'master' into removal/tuner
awaelchli Jan 27, 2023
75 changes: 3 additions & 72 deletions docs/source-pytorch/common/trainer.rst
@@ -287,69 +287,6 @@ Example::
# no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})

auto_scale_batch_size
^^^^^^^^^^^^^^^^^^^^^

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/auto_scale%E2%80%A8_batch_size.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/auto_scale_batch_size.mp4"></video>

|

Automatically tries to find the largest batch size that fits into memory,
before any training.

.. code-block:: python

# default used by the Trainer (no scaling of batch size)
trainer = Trainer(auto_scale_batch_size=None)

# run batch size scaling, result overrides hparams.batch_size
trainer = Trainer(auto_scale_batch_size="binsearch")

# call tune to find the batch size
trainer.tune(model)


auto_lr_find
^^^^^^^^^^^^

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/auto_lr_find.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/auto_lr_find.mp4"></video>

|

Runs a learning rate finder algorithm (see this `paper <https://arxiv.org/abs/1506.01186>`_)
when calling trainer.tune(), to find optimal initial learning rate.

.. code-block:: python

# default used by the Trainer (no learning rate finder)
trainer = Trainer(auto_lr_find=False)

Example::

# run learning rate finder, results override hparams.learning_rate
trainer = Trainer(auto_lr_find=True)

# call tune to find the lr
trainer.tune(model)

Example::

# run learning rate finder, results override hparams.my_lr_arg
trainer = Trainer(auto_lr_find='my_lr_arg')

# call tune to find the lr
trainer.tune(model)

.. note::
See the :ref:`learning rate finder guide <learning_rate_finder>`.

benchmark
^^^^^^^^^
@@ -617,7 +554,7 @@ impact to subsequent runs. These are the changes enabled:
- The :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks will not trigger.
- The :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callbacks will not trigger.
- Sets ``limit_{train,val,test,predict}_batches`` to 1 or the number passed.
- Disables the Tuner.
- Disables the tuning callbacks (:class:`~pytorch_lightning.callbacks.batch_size_finder.BatchSizeFinder`, :class:`~pytorch_lightning.callbacks.lr_finder.LearningRateFinder`).
- If using the CLI, the configuration file is not saved.


@@ -1358,12 +1295,6 @@ predict
.. automethod:: pytorch_lightning.trainer.Trainer.predict
:noindex:

tune
****

.. automethod:: pytorch_lightning.trainer.Trainer.tune
:noindex:


Properties
^^^^^^^^^^
@@ -1523,11 +1454,11 @@ execution within that function, and the status of the Trainer.

.. code-block:: python

# fn in ("fit", "validate", "test", "predict", "tune")
# fn in ("fit", "validate", "test", "predict")
trainer.state.fn
# status in ("initializing", "running", "finished", "interrupted")
trainer.state.status
# stage in ("train", "sanity_check", "validate", "test", "predict", "tune")
# stage in ("train", "sanity_check", "validate", "test", "predict")
trainer.state.stage

should_stop
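For orientation, here is a minimal migration sketch matching the ``trainer.rst`` changes above. It assumes the decoupled tuner is exposed as ``pytorch_lightning.tuner.tuning.Tuner`` with ``scale_batch_size`` and ``lr_find`` methods (names inferred from this PR, not guaranteed by it); ``MyModel`` is a placeholder ``LightningModule``::

    from pytorch_lightning import Trainer
    from pytorch_lightning.tuner.tuning import Tuner

    model = MyModel()  # placeholder LightningModule defined elsewhere

    # Before this PR: tuning was driven by Trainer flags and trainer.tune()
    #   trainer = Trainer(auto_scale_batch_size="binsearch", auto_lr_find=True)
    #   trainer.tune(model)

    # After this PR: the tuner wraps an existing Trainer instead of living inside it
    trainer = Trainer(max_epochs=10)
    tuner = Tuner(trainer)
    tuner.scale_batch_size(model, mode="binsearch")  # writes the result back to model.batch_size
    tuner.lr_find(model)  # runs the LR range test; see the learning rate finder guide
    trainer.fit(model)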
11 changes: 5 additions & 6 deletions src/pytorch_lightning/callbacks/lr_finder.py
@@ -32,11 +32,8 @@ class LearningRateFinder(Callback):

Args:
min_lr: Minimum learning rate to investigate

max_lr: Maximum learning rate to investigate

num_training_steps: Number of learning rates to test

mode: Search strategy to update learning rate after each batch:

- ``'exponential'`` (default): Increases the learning rate exponentially.
@@ -45,7 +42,6 @@ class LearningRateFinder(Callback):
early_stop_threshold: Threshold for stopping the search. If the
loss at any point is larger than early_stop_threshold*best_loss
then the search is stopped. To disable, set to None.

update_attr: Whether to update the learning rate attribute or not.

Example::
@@ -73,8 +69,8 @@ def on_train_epoch_start(self, trainer, pl_module):

Raises:
MisconfigurationException:
If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``,
or if you are using more than one optimizer.
If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden, or if you are using more than
one optimizer.
"""

SUPPORTED_MODES = ("linear", "exponential")
@@ -87,6 +83,7 @@ def __init__(
mode: str = "exponential",
early_stop_threshold: Optional[float] = 4.0,
update_attr: bool = False,
attr_name: str = ""
) -> None:
mode = mode.lower()
if mode not in self.SUPPORTED_MODES:
@@ -98,6 +95,7 @@ def __init__(
self._mode = mode
self._early_stop_threshold = early_stop_threshold
self._update_attr = update_attr
self._attr_name = attr_name

self._early_exit = False
self.lr_finder: Optional[_LRFinder] = None
@@ -113,6 +111,7 @@ def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
mode=self._mode,
early_stop_threshold=self._early_stop_threshold,
update_attr=self._update_attr,
attr_name=self._attr_name,
)

if self._early_exit:
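A short usage sketch for the callback with the new ``attr_name`` argument added in this file. Treat it as illustrative only: ``my_lr`` is a hypothetical hyperparameter on the user's ``LightningModule``, and the remaining arguments follow the signature shown in this diff::

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateFinder

    # Run the LR range test at the start of training and, because update_attr=True,
    # write the suggested value back to model.my_lr instead of the default
    # lr / learning_rate lookup.
    lr_finder_cb = LearningRateFinder(
        min_lr=1e-6,
        max_lr=1e-1,
        mode="exponential",
        update_attr=True,
        attr_name="my_lr",  # hypothetical attribute name on the LightningModule
    )
    trainer = Trainer(callbacks=[lr_finder_cb])
    trainer.fit(model)  # `model` is a LightningModule defined elsewhere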
1 change: 0 additions & 1 deletion src/pytorch_lightning/cli.py
@@ -436,7 +436,6 @@ def subcommands() -> Dict[str, Set[str]]:
"validate": {"model", "dataloaders", "datamodule"},
"test": {"model", "dataloaders", "datamodule"},
"predict": {"model", "dataloaders", "datamodule"},
"tune": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
}

def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None:
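With the ``tune`` subcommand removed above, a tuning run from the CLI has to go through the callbacks instead. A sketch under that assumption, where ``DemoModel`` and ``DemoDataModule`` are placeholder classes::

    from pytorch_lightning.callbacks import BatchSizeFinder, LearningRateFinder
    from pytorch_lightning.cli import LightningCLI

    # `fit`, `validate`, `test` and `predict` remain as subcommands; tuning now
    # happens inside `fit` through the callbacks rather than a `tune` subcommand.
    cli = LightningCLI(
        DemoModel,
        DemoDataModule,
        trainer_defaults={"callbacks": [BatchSizeFinder(), LearningRateFinder()]},
    )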
91 changes: 2 additions & 89 deletions src/pytorch_lightning/trainer/trainer.py
Expand Up @@ -40,7 +40,6 @@
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from typing_extensions import Literal

import pytorch_lightning as pl
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
@@ -77,7 +76,6 @@
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.tuner.tuning import _TunerResult, Tuner
from pytorch_lightning.utilities import GradClipAlgorithmType, parsing
from pytorch_lightning.utilities.argparse import (
_defaults_from_env_vars,
@@ -147,10 +145,8 @@ def __init__(
benchmark: Optional[bool] = None,
deterministic: Optional[Union[bool, _LITERAL_WARN]] = None,
reload_dataloaders_every_n_epochs: int = 0,
auto_lr_find: Union[bool, str] = False,
replace_sampler_ddp: bool = True,
detect_anomaly: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
multiple_trainloader_mode: str = "max_size_cycle",
inference_mode: bool = True,
@@ -166,31 +162,6 @@
accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
Default: ``None``.

auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
trying to optimize initial learning for faster convergence. trainer.tune() method will
set the suggested learning rate in self.lr or self.learning_rate in the LightningModule.
To use a different key set a string instead of True with the key name.
Default: ``False``.

auto_scale_batch_size: If set to True, will `initially` run a batch size
finder trying to find the largest batch size that fits into memory.
The result will be stored in self.batch_size in the LightningModule
or LightningDataModule depending on your setup.
Additionally, can be set to either `power` that estimates the batch size through
a power search or `binsearch` that estimates the batch size through a binary search.
Default: ``False``.

auto_select_gpus: If enabled and ``gpus`` or ``devices`` is an integer, pick available
gpus automatically. This is especially useful when
GPUs are configured to be in "exclusive mode", such
that only one process at a time can access them.
Default: ``False``.

.. deprecated:: v1.9
``auto_select_gpus`` has been deprecated in v1.9.0 and will be removed in v2.0.0.
Please use the function :func:`~lightning_fabric.accelerators.cuda.find_usable_cuda_devices`
instead.

benchmark: The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
(``False`` if not manually set). If :paramref:`~pytorch_lightning.trainer.Trainer.deterministic` is set
@@ -364,7 +335,6 @@ def __init__(
self._callback_connector = CallbackConnector(self)
self._checkpoint_connector = CheckpointConnector(self)
self._signal_connector = SignalConnector(self)
self.tuner = Tuner(self)

# init loops
self.fit_loop = _FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
@@ -428,9 +398,6 @@ def __init__(
self._detect_anomaly: bool = detect_anomaly
self._setup_on_init()

# configure tuner
self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

# configure profiler
setup._init_profiler(self, profiler)

@@ -834,60 +801,6 @@ def _predict_impl(

return results

def tune(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
lr_find_kwargs: Optional[Dict[str, Any]] = None,
method: Literal["fit", "validate", "test", "predict"] = "fit",
) -> _TunerResult:
r"""
Runs routines to tune hyperparameters before training.

Args:
model: Model to tune.

train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.

val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.

dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
samples used for running tuner on validation/testing/prediction.

datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size`

lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find`

method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
"""
model = self._maybe_unwrap_optimized(model)
if not isinstance(model, pl.LightningModule):
raise TypeError(f"`Trainer.tune()` requires a `LightningModule`, got: {model.__class__.__qualname__}")

Trainer._log_api_event("tune")

with isolate_rng():
result = self.tuner._tune(
model,
train_dataloaders,
val_dataloaders,
dataloaders,
datamodule,
scale_batch_size_kwargs=scale_batch_size_kwargs,
lr_find_kwargs=lr_find_kwargs,
method=method,
)

return result

def _run(
self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
@@ -1656,10 +1569,10 @@ def model(self) -> Optional[torch.nn.Module]:
"""
return self.strategy.model

# TODO: is this still needed
@model.setter
def model(self, model: torch.nn.Module) -> None:
"""Setter for the model, pass-through to accelerator and plugin where the model reference is stored. Used
by the Tuner to reset the state of Trainer and Accelerator.
"""Setter for the model, pass-through to accelerator and plugin where the model reference is stored.

Args:
model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending
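The ``Trainer.tune()`` method removed earlier in this file drove both finders through the Trainer itself. A hedged sketch of the callback-based equivalent referenced by the ``fast_dev_run`` note in the docs, with argument names taken from the batch size scaling docstring in the next file; ``model`` and ``datamodule`` are assumed to be defined elsewhere::

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import BatchSizeFinder

    # The finder runs at the start of fit(), grows the batch size until it hits
    # an OOM error, and writes the result back to `model.batch_size`
    # (or `datamodule.batch_size`, depending on where the attribute lives).
    batch_finder = BatchSizeFinder(mode="power", steps_per_trial=3, init_val=2, max_trials=25)
    trainer = Trainer(callbacks=[batch_finder])
    trainer.fit(model, datamodule=datamodule)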
26 changes: 26 additions & 0 deletions src/pytorch_lightning/tuner/batch_size_scaling.py
@@ -34,6 +34,32 @@ def scale_batch_size(
max_trials: int = 25,
batch_arg_name: str = "batch_size",
) -> Optional[int]:
"""Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
error.

Args:
trainer: A Trainer instance.
model: Model to tune.
mode: Search strategy to update the batch size:

- ``'power'``: Keep multiplying the batch size by 2, until we get an OOM error.
- ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error
do a binary search between the last successful batch size and the batch size that failed.

steps_per_trial: number of steps to run with a given batch size.
Ideally 1 should be enough to test if an OOM error occurs,
however in practice a few are needed.
init_val: initial batch size to start the search with
max_trials: max number of increases in batch size done before
the algorithm is terminated.
batch_arg_name: name of the attribute that stores the batch size.
It is expected that the user has provided a model or datamodule that has a hyperparameter
with that name. We will look for this attribute name in the following places:

- ``model``
- ``model.hparams``
- ``trainer.datamodule`` (the datamodule passed to the tune method)
"""
if trainer.fast_dev_run:
rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
return None
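To make the ``batch_arg_name`` lookup described in the docstring above concrete, here is a self-contained sketch: the module stores ``self.batch_size`` and uses it in its dataloader, and the (assumed) ``Tuner.scale_batch_size`` wrapper forwards the same arguments to this function. ``DemoModule`` is a placeholder model, not part of the library::

    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning.tuner.tuning import Tuner


    class DemoModule(pl.LightningModule):
        def __init__(self, batch_size: int = 2):
            super().__init__()
            self.batch_size = batch_size  # found via batch_arg_name="batch_size"
            self.layer = torch.nn.Linear(32, 2)

        def train_dataloader(self):
            data = TensorDataset(torch.randn(640, 32), torch.randint(0, 2, (640,)))
            return DataLoader(data, batch_size=self.batch_size)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    model = DemoModule()
    trainer = pl.Trainer(max_epochs=1)
    # mutates model.batch_size in place and returns the largest size that fit
    new_size = Tuner(trainer).scale_batch_size(model, mode="power", batch_arg_name="batch_size")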