Decouple Tuner from Trainer #16462

Merged
merged 40 commits into master from removal/tuner on Jan 27, 2023
Changes from 31 commits

Commits (40)
427f428
removal
awaelchli Jan 21, 2023
ce1da14
delete
awaelchli Jan 21, 2023
6568581
remove
awaelchli Jan 21, 2023
e864e92
api docs
awaelchli Jan 21, 2023
ce3ac54
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 21, 2023
ced78cc
attr_name
awaelchli Jan 22, 2023
2799a23
Merge branch 'master' into removal/tuner
awaelchli Jan 26, 2023
86cd0e4
tests
awaelchli Jan 26, 2023
ffb2bfa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2023
42793d6
revert
awaelchli Jan 26, 2023
e9cb7a9
Merge remote-tracking branch 'origin/removal/tuner' into removal/tuner
awaelchli Jan 26, 2023
ebe6795
checks
awaelchli Jan 26, 2023
166f120
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2023
3f6a890
test tuning
awaelchli Jan 26, 2023
c4360b8
Merge remote-tracking branch 'origin/removal/tuner' into removal/tuner
awaelchli Jan 26, 2023
d6761e3
fixes
awaelchli Jan 26, 2023
37b8321
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2023
40ebcc6
fixes
awaelchli Jan 27, 2023
bdb6ae3
Merge remote-tracking branch 'origin/removal/tuner' into removal/tuner
awaelchli Jan 27, 2023
54d13fb
refactor
awaelchli Jan 27, 2023
32af540
docstring
awaelchli Jan 27, 2023
549c988
tests
awaelchli Jan 27, 2023
718d6ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2023
5a942e4
docs
awaelchli Jan 27, 2023
6431f4e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2023
ffdf205
remove setter
awaelchli Jan 27, 2023
7dc7e6c
changelog and defaults
awaelchli Jan 27, 2023
5151c2f
Merge remote-tracking branch 'origin/removal/tuner' into removal/tuner
awaelchli Jan 27, 2023
a22d0a9
types
awaelchli Jan 27, 2023
ccb56ae
chlog
awaelchli Jan 27, 2023
0d87706
resolve circular import
awaelchli Jan 27, 2023
c023d4a
Merge branch 'master' into removal/tuner
carmocca Jan 27, 2023
3f62661
Update src/pytorch_lightning/tuner/lr_finder.py
awaelchli Jan 27, 2023
ec6cbdf
remove resolved todo for circular import
awaelchli Jan 27, 2023
1665bd6
Merge remote-tracking branch 'origin/removal/tuner' into removal/tuner
awaelchli Jan 27, 2023
3298ea0
Merge branch 'master' into removal/tuner
carmocca Jan 27, 2023
2a449ac
pre-commit
carmocca Jan 27, 2023
2218060
Remove stale TODO
carmocca Jan 27, 2023
8687582
Update src/pytorch_lightning/CHANGELOG.md
awaelchli Jan 27, 2023
68fbbfd
Merge branch 'master' into removal/tuner
awaelchli Jan 27, 2023
92 changes: 37 additions & 55 deletions docs/source-pytorch/advanced/training_tricks.rst
@@ -78,18 +78,24 @@ Auto-scaling of batch size can be enabled to find the largest batch size that fits into
memory. Large batch size often yields a better estimation of the gradients, but may also result in
longer training time. Inspired by https://github.com/BlackHC/toma.

.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`
.. seealso:: :class:`~pytorch_lightning.tuner.tuning.Tuner`

.. code-block:: python

# DEFAULT (ie: don't scale batch size automatically)
trainer = Trainer(auto_scale_batch_size=None)
from pytorch_lightning.tuner import Tuner

# Create a tuner for the trainer
trainer = Trainer(...)
tuner = Tuner(trainer)

# Autoscale batch size
trainer = Trainer(auto_scale_batch_size=None | "power" | "binsearch")
# Auto-scale batch size by growing it exponentially (default)
tuner.scale_batch_size(model, mode="power")

# Find the batch size
trainer.tune(model)
# Auto-scale batch size with binary search
tuner.scale_batch_size(model, mode="binsearch")

# Fit as normal with new batch size
trainer.fit(model)


Currently, this feature supports two modes: ``'power'`` scaling and ``'binsearch'``
@@ -122,9 +128,10 @@ search for batch sizes larger than the size of the training dataset.
return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)


trainer = Trainer(...)
model = LitModel(batch_size=32)
trainer.tune(model)
trainer = Trainer(...)
tuner = Tuner(trainer)
tuner.scale_batch_size(model)

# using LightningDataModule
class LitDataModule(LightningDataModule):
@@ -138,40 +145,19 @@ search for batch sizes larger than the size of the training dataset.
return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)


trainer = Trainer(...)
model = MyModel()
datamodule = LitDataModule(batch_size=32)
trainer.tune(model, datamodule=datamodule)

trainer = Trainer(...)
tuner = Tuner(trainer)
tuner.scale_batch_size(model, datamodule=datamodule)

Note that the ``train_dataloader`` can be part of either
the ``LightningModule`` or the ``LightningDataModule``,
as shown above. If both the ``LightningModule``
and the ``LightningDataModule`` define a ``train_dataloader``,
the ``LightningDataModule`` takes precedence, as the sketch below illustrates.
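
For illustration, a minimal sketch combining the two patterns above (it reuses the ``LitModel``
and ``LitDataModule`` classes defined in this section; the instances are hypothetical):

.. code-block:: python

    # Both objects define a train_dataloader here, so the tuner
    # uses the one from the LightningDataModule.
    model = LitModel(batch_size=32)
    datamodule = LitDataModule(batch_size=32)

    tuner = Tuner(trainer)
    tuner.scale_batch_size(model, datamodule=datamodule)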

.. warning::

Due to the constraints listed above, this feature does *NOT* work when passing dataloaders directly
to ``.fit()``.

The scaling algorithm has a number of parameters that the user can control by
invoking the :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size` method:

.. code-block:: python

# Use default in trainer construction
trainer = Trainer()
tuner = Tuner(trainer)

# Invoke method
new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here)

# Override old batch size (this is done automatically)
model.hparams.batch_size = new_batch_size

# Fit as normal
trainer.fit(model)

In short, the algorithm works by (a conceptual sketch of the loop follows these steps):
1. Dumping the current state of the model and trainer
2. Iteratively until convergence or maximum number of tries ``max_trials`` (default 25) has been reached:
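
As a rough illustration of the ``'power'`` mode, here is a conceptual sketch (not the library's
implementation, which lives in ``pytorch_lightning.tuner.batch_size_scaling``; ``run_trial_steps``
is a hypothetical helper that runs a few training steps at a given batch size):

.. code-block:: python

    def power_scale(run_trial_steps, init_val=2, max_trials=25):
        batch_size = init_val
        for _ in range(max_trials):
            try:
                run_trial_steps(batch_size)
            except RuntimeError as err:
                if "out of memory" in str(err).lower():
                    return max(batch_size // 2, 1)  # back off to the last size that fit
                raise
            batch_size *= 2  # this size fits, so double it and try again
        return batch_size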
@@ -247,14 +233,6 @@ Customizing Batch Size Finder
Learning Rate Finder
********************

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/auto_lr_find.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/auto_lr_find.mp4"></video>

|

For training deep neural networks, selecting a good learning rate is essential
for both better performance and faster convergence. Even optimizers such as
:class:`~torch.optim.Adam` that self-adjust the learning rate can benefit from more optimal
@@ -284,16 +262,17 @@ Using Lightning's built-in LR finder

To enable the learning rate finder, your :doc:`lightning module <../common/lightning_module>` needs to
have a ``learning_rate`` or ``lr`` attribute (or a field in your ``hparams``, i.e.
``hparams.learning_rate`` or ``hparams.lr``). Then, set ``Trainer(auto_lr_find=True)``
during trainer construction, and then call ``trainer.tune(model)`` to run the LR finder.
``hparams.learning_rate`` or ``hparams.lr``). Then, create the :class:`~pytorch_lightning.tuner.tuning.Tuner` via ``tuner = Tuner(trainer)``
and call ``tuner.lr_find(model)`` to run the LR finder.
The suggested ``learning_rate`` will be written to the console and will be automatically
set to your :doc:`lightning module <../common/lightning_module>`, which can be accessed
via ``self.learning_rate`` or ``self.lr``.

.. seealso:: :ref:`trainer.tune <common/trainer:tune>`.

.. code-block:: python

from pytorch_lightning.tuner import Tuner


class LitModel(LightningModule):
def __init__(self, learning_rate):
super().__init__()
@@ -305,36 +284,39 @@ via ``self.learning_rate`` or ``self.lr``.


model = LitModel()
trainer = Trainer(...)

# Create a Tuner
tuner = Tuner(trainer)

# finds learning rate automatically
# sets hparams.lr or hparams.learning_rate to that learning rate
trainer = Trainer(auto_lr_find=True)

trainer.tune(model)
tuner.lr_find(model)


If your model is using an arbitrary value instead of ``self.lr`` or ``self.learning_rate``, set that value as ``auto_lr_find``:
If your model uses an arbitrary attribute name instead of ``self.lr`` or ``self.learning_rate``, pass that name to ``lr_find``:

.. code-block:: python

model = LitModel()
trainer = Trainer(...)
tuner = Tuner(trainer)

# to set to your own hparams.my_value
trainer = Trainer(auto_lr_find="my_value")
tuner.lr_find(model, attr_name="my_value")

trainer.tune(model)

You can also inspect the results of the learning rate finder or just play around
with the parameters of the algorithm. This can be done by invoking the
:meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find` method. A typical example of this would look like:
with the parameters of the algorithm. A typical example of this would look like:

.. code-block:: python

model = MyModelClass(hparams)
trainer = Trainer()
tuner = Tuner(trainer)

# Run learning rate finder
lr_finder = trainer.tuner.lr_find(model)
lr_finder = tuner.lr_find(model)

# Results can be found in
print(lr_finder.results)
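
A typical follow-up, assuming the returned object exposes ``plot()`` and ``suggestion()`` as in
Lightning's LR finder (a sketch, not part of this diff):

.. code-block:: python

    # Plot the loss vs. learning rate curve with the suggested point marked
    fig = lr_finder.plot(suggest=True)
    fig.show()

    # Pick the suggested learning rate and continue training with it
    new_lr = lr_finder.suggestion()
    model.hparams.lr = new_lr
    trainer.fit(model)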
75 changes: 3 additions & 72 deletions docs/source-pytorch/common/trainer.rst
@@ -287,69 +287,6 @@ Example::
# no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})

auto_scale_batch_size
^^^^^^^^^^^^^^^^^^^^^

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/auto_scale%E2%80%A8_batch_size.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/auto_scale_batch_size.mp4"></video>

|

Automatically tries to find the largest batch size that fits into memory,
before any training.

.. code-block:: python

# default used by the Trainer (no scaling of batch size)
trainer = Trainer(auto_scale_batch_size=None)

# run batch size scaling, result overrides hparams.batch_size
trainer = Trainer(auto_scale_batch_size="binsearch")

# call tune to find the batch size
trainer.tune(model)


auto_lr_find
^^^^^^^^^^^^

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/auto_lr_find.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/auto_lr_find.mp4"></video>

|

Runs a learning rate finder algorithm (see this `paper <https://arxiv.org/abs/1506.01186>`_)
when calling trainer.tune(), to find an optimal initial learning rate.

.. code-block:: python

# default used by the Trainer (no learning rate finder)
trainer = Trainer(auto_lr_find=False)

Example::

# run learning rate finder, results override hparams.learning_rate
trainer = Trainer(auto_lr_find=True)

# call tune to find the lr
trainer.tune(model)

Example::

# run learning rate finder, results override hparams.my_lr_arg
trainer = Trainer(auto_lr_find='my_lr_arg')

# call tune to find the lr
trainer.tune(model)

.. note::
See the :ref:`learning rate finder guide <learning_rate_finder>`.

benchmark
^^^^^^^^^
@@ -617,7 +554,7 @@ impact to subsequent runs. These are the changes enabled:
- The :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks will not trigger.
- The :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callbacks will not trigger.
- Sets ``limit_{train,val,test,predict}_batches`` to 1 or the number passed.
- Disables the Tuner.
- Disables the tuning callbacks (:class:`~pytorch_lightning.callbacks.batch_size_finder.BatchSizeFinder`, :class:`~pytorch_lightning.callbacks.lr_finder.LearningRateFinder`); a short usage sketch follows this list.
- If using the CLI, the configuration file is not saved.
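
As a sketch of how this interacts with the tuning callbacks named above (assumed usage with the
callbacks' default arguments; not part of this diff):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
    from pytorch_lightning.callbacks.lr_finder import LearningRateFinder

    # With fast_dev_run enabled, both tuning callbacks are skipped,
    # along with checkpointing and early stopping (see the list above).
    trainer = Trainer(fast_dev_run=True, callbacks=[BatchSizeFinder(), LearningRateFinder()])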


@@ -1358,12 +1295,6 @@ predict
.. automethod:: pytorch_lightning.trainer.Trainer.predict
:noindex:

tune
****

.. automethod:: pytorch_lightning.trainer.Trainer.tune
:noindex:


Properties
^^^^^^^^^^
@@ -1523,11 +1454,11 @@ execution within that function, and the status of the Trainer.

.. code-block:: python

# fn in ("fit", "validate", "test", "predict", "tune")
# fn in ("fit", "validate", "test", "predict")
trainer.state.fn
# status in ("initializing", "running", "finished", "interrupted")
trainer.state.status
# stage in ("train", "sanity_check", "validate", "test", "predict", "tune")
# stage in ("train", "sanity_check", "validate", "test", "predict")
trainer.state.stage

should_stop
19 changes: 14 additions & 5 deletions src/pytorch_lightning/CHANGELOG.md
@@ -11,6 +11,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added an argument `include_cuda` in `pytorch_lightning.utilities.seed.isolate_rng` to disable managing `torch.cuda`'s rng ([#16423](https://github.com/Lightning-AI/lightning/pull/16423))

- Added migration logic to warn about checkpoints with apex AMP state ([#16161](https://github.com/Lightning-AI/lightning/pull/16161))

- Added the `Trainer.ckpt_path = ...` setter to statefully set the checkpoint path to load. This can act as a replacement for the removed `Trainer(resume_from_checkpoint=...)` flag ([#16187](https://github.com/Lightning-AI/lightning/pull/16187))

- Added `Tuner.lr_find(attr_name=...)` to specify custom learning rate attribute names ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))


### Changed

@@ -22,16 +28,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* `pytorch_lightning.plugins.precision.native_amp` is now `pytorch_lightning.plugins.precision.amp`
* `NativeSyncBatchNorm` is now `TorchSyncBatchNorm`

### Deprecated
- Changed the default of `LearningRateFinder(update_attr=...)` and `Tuner.lr_find(update_attr=...)` to `True` ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))

-

### Deprecated

### Added
-

- Added migration logic to warn about checkpoints with apex AMP state ([#16161](https://github.com/Lightning-AI/lightning/pull/16161))

- Added the `Trainer.ckpt_path = ...` setter to statefully set the checkpoint path to load. This can act as a replacement for the removed `Trainer(resume_from_checkpoint=...)` flag ([#16187](https://github.com/Lightning-AI/lightning/pull/16187))

### Removed

@@ -138,9 +142,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed support for passing a dictionary value to `self.log()` ([#16389](https://github.com/Lightning-AI/lightning/pull/16389))

- Removed `Trainer.model` setter ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))

- Tuner removal (see the migration sketch after this list)
  * Removed the deprecated `trainer.tuning` property ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))
  * Removed the deprecated `TrainerFn.TUNING` and `RunningStage.TUNING` enums ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))
  * Removed `Trainer.tune()` in favor of `Tuner(trainer).{lr_find,scale_batch_size}` ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
  * Removed `Trainer(auto_scale_batch_size=...)` in favor of `Tuner(trainer).scale_batch_size()` ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
  * Removed `Trainer(auto_lr_find=...)` in favor of `Tuner(trainer).lr_find()` ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
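
For readers migrating, a rough before/after sketch based on the entries above (not part of the
diff itself; it assumes `model` is a `LightningModule`):

```python
# Before (flags and method removed by this PR):
#   trainer = Trainer(auto_scale_batch_size="binsearch", auto_lr_find=True)
#   trainer.tune(model)

# After:
from pytorch_lightning import Trainer
from pytorch_lightning.tuner import Tuner

trainer = Trainer()
tuner = Tuner(trainer)
tuner.scale_batch_size(model, mode="binsearch")
tuner.lr_find(model)
trainer.fit(model)
```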

### Fixed

4 changes: 2 additions & 2 deletions src/pytorch_lightning/callbacks/batch_size_finder.py
@@ -22,7 +22,7 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
from pytorch_lightning.tuner.batch_size_scaling import _scale_batch_size
from pytorch_lightning.utilities.exceptions import _TunerExitException, MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@@ -165,7 +165,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
)

def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
new_size = scale_batch_size(
new_size = _scale_batch_size(
trainer,
pl_module,
self._mode,
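
For context, a sketch of how this callback is typically attached to a trainer (assumed usage with
the arguments shown in this file; `model` and `datamodule` are placeholders):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder

    # Before training starts, the callback calls _scale_batch_size() to find a
    # batch size that fits in memory and writes it back to the `batch_size`
    # attribute on the model, its hparams, or the datamodule.
    trainer = pl.Trainer(callbacks=[BatchSizeFinder(mode="binsearch")])
    trainer.fit(model, datamodule=datamodule)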