
Commit b216a11

Decouple Tuner from Trainer (#16462)

Authored by awaelchli and carmocca
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent c8c4722 commit b216a11

22 files changed: +339 / -595 lines changed
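In short, the commit removes `Trainer(auto_scale_batch_size=...)`, `Trainer(auto_lr_find=...)` and `Trainer.tune()` and moves tuning into an explicit `Tuner` object, as the doc diffs below show. A minimal before/after sketch of the migration, assuming a user-defined `LitModel` LightningModule (not part of this commit):

from pytorch_lightning import Trainer
from pytorch_lightning.tuner import Tuner

model = LitModel()  # placeholder LightningModule with batch_size and lr/learning_rate attributes

# Old API (removed by this commit):
#   trainer = Trainer(auto_scale_batch_size="power", auto_lr_find=True)
#   trainer.tune(model)

# New API: an explicit Tuner wraps an existing Trainer.
trainer = Trainer(max_epochs=3)
tuner = Tuner(trainer)
tuner.scale_batch_size(model, mode="power")  # result is written back to model.batch_size / hparams.batch_size
tuner.lr_find(model)                         # suggestion is written back to model.lr / model.learning_rate
trainer.fit(model)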

docs/source-pytorch/advanced/training_tricks.rst

Lines changed: 37 additions & 55 deletions
@@ -78,18 +78,24 @@ Auto-scaling of batch size can be enabled to find the largest batch size that fits into
 memory. Large batch size often yields a better estimation of the gradients, but may also result in
 longer training time. Inspired by https://github.com/BlackHC/toma.

-.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`
+.. seealso:: :class:`~pytorch_lightning.tuner.tuning.Tuner`

 .. code-block:: python

-    # DEFAULT (ie: don't scale batch size automatically)
-    trainer = Trainer(auto_scale_batch_size=None)
+    from pytorch_lightning.tuner import Tuner
+
+    # Create a tuner for the trainer
+    trainer = Trainer(...)
+    tuner = Tuner(trainer)

-    # Autoscale batch size
-    trainer = Trainer(auto_scale_batch_size=None | "power" | "binsearch")
+    # Auto-scale batch size by growing it exponentially (default)
+    tuner.scale_batch_size(model, mode="power")

-    # Find the batch size
-    trainer.tune(model)
+    # Auto-scale batch size with binary search
+    tuner.scale_batch_size(model, mode="binsearch")
+
+    # Fit as normal with new batch size
+    trainer.fit(model)


 Currently, this feature supports two modes ``'power'`` scaling and ``'binsearch'``
@@ -122,9 +128,10 @@ search for batch sizes larger than the size of the training dataset.
             return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)


-    trainer = Trainer(...)
     model = LitModel(batch_size=32)
-    trainer.tune(model)
+    trainer = Trainer(...)
+    tuner = Tuner(trainer)
+    tuner.scale_batch_size(model)

     # using LightningDataModule
     class LitDataModule(LightningDataModule):
@@ -138,40 +145,19 @@ search for batch sizes larger than the size of the training dataset.
             return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)


-    trainer = Trainer(...)
     model = MyModel()
     datamodule = LitDataModule(batch_size=32)
-    trainer.tune(model, datamodule=datamodule)
+
+    trainer = Trainer(...)
+    tuner = Tuner(trainer)
+    tuner.scale_batch_size(model, datamodule=datamodule)

 Note that the ``train_dataloader`` can be either part of
 the ``LightningModule`` or ``LightningDataModule``
 as shown above. If both the ``LightningModule``
 and the ``LightningDataModule`` contain a ``train_dataloader``,
 the ``LightningDataModule`` takes precedence.

-.. warning::
-
-    Due to the constraints listed above, this features does *NOT* work when passing dataloaders directly
-    to ``.fit()``.
-
-The scaling algorithm has a number of parameters that the user can control by
-invoking the :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size` method:
-
-.. code-block:: python
-
-    # Use default in trainer construction
-    trainer = Trainer()
-    tuner = Tuner(trainer)
-
-    # Invoke method
-    new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here)
-
-    # Override old batch size (this is done automatically)
-    model.hparams.batch_size = new_batch_size
-
-    # Fit as normal
-    trainer.fit(model)
-
 The algorithm in short works by:
 1. Dumping the current state of the model and trainer
 2. Iteratively until convergence or maximum number of tries ``max_trials`` (default 25) has been reached:
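The removed paragraph above also described that the found batch size is returned and written back to the model automatically; the same behaviour applies to the new `Tuner.scale_batch_size` shown in the hunks above. A rough sketch under that assumption (`LitModel` is a placeholder module with a `batch_size` attribute, and `max_trials` is the knob named in the algorithm description above):

from pytorch_lightning import Trainer
from pytorch_lightning.tuner import Tuner

model = LitModel(batch_size=32)
trainer = Trainer()
tuner = Tuner(trainer)

# Returns the largest working batch size and, as in the removed snippet,
# also overrides model.batch_size / model.hparams.batch_size in place.
new_batch_size = tuner.scale_batch_size(model, mode="binsearch", max_trials=25)
print(new_batch_size)

trainer.fit(model)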
@@ -247,14 +233,6 @@ Customizing Batch Size Finder
 Learning Rate Finder
 ********************

-.. raw:: html
-
-    <video width="50%" max-width="400px" controls
-    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/auto_lr_find.jpg"
-    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/auto_lr_find.mp4"></video>
-
-|
-
 For training deep neural networks, selecting a good learning rate is essential
 for both better performance and faster convergence. Even optimizers such as
 :class:`~torch.optim.Adam` that are self-adjusting the learning rate can benefit from more optimal
@@ -284,16 +262,17 @@ Using Lightning's built-in LR finder

 To enable the learning rate finder, your :doc:`lightning module <../common/lightning_module>` needs to
 have a ``learning_rate`` or ``lr`` attribute (or as a field in your ``hparams`` i.e.
-``hparams.learning_rate`` or ``hparams.lr``). Then, set ``Trainer(auto_lr_find=True)``
-during trainer construction, and then call ``trainer.tune(model)`` to run the LR finder.
+``hparams.learning_rate`` or ``hparams.lr``). Then, create the :class:`~pytorch_lightning.tuner.tuning.Tuner` via ``tuner = Tuner(trainer)``
+and call ``tuner.lr_find(model)`` to run the LR finder.
 The suggested ``learning_rate`` will be written to the console and will be automatically
 set to your :doc:`lightning module <../common/lightning_module>`, which can be accessed
 via ``self.learning_rate`` or ``self.lr``.

-.. seealso:: :ref:`trainer.tune <common/trainer:tune>`.
-
 .. code-block:: python

+    from pytorch_lightning.tuner import Tuner
+
+
     class LitModel(LightningModule):
         def __init__(self, learning_rate):
             super().__init__()
@@ -305,36 +284,39 @@ via ``self.learning_rate`` or ``self.lr``.


     model = LitModel()
+    trainer = Trainer(...)
+
+    # Create a Tuner
+    tuner = Tuner(trainer)

     # finds learning rate automatically
     # sets hparams.lr or hparams.learning_rate to that learning rate
-    trainer = Trainer(auto_lr_find=True)
-
-    trainer.tune(model)
+    tuner.lr_find(model)


-If your model is using an arbitrary value instead of ``self.lr`` or ``self.learning_rate``, set that value as ``auto_lr_find``:
+If your model is using an arbitrary value instead of ``self.lr`` or ``self.learning_rate``, set that value in ``lr_find``:

 .. code-block:: python

     model = LitModel()
+    trainer = Trainer(...)
+    tuner = Tuner(trainer)

     # to set to your own hparams.my_value
-    trainer = Trainer(auto_lr_find="my_value")
+    tuner.lr_find(model, attr_name="my_value")

-    trainer.tune(model)

 You can also inspect the results of the learning rate finder or just play around
-with the parameters of the algorithm. This can be done by invoking the
-:meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find` method. A typical example of this would look like:
+with the parameters of the algorithm. A typical example of this would look like:

 .. code-block:: python

     model = MyModelClass(hparams)
     trainer = Trainer()
+    tuner = Tuner(trainer)

     # Run learning rate finder
-    lr_finder = trainer.tuner.lr_find(model)
+    lr_finder = tuner.lr_find(model)

     # Results can be found in
     print(lr_finder.results)
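A short follow-up to the snippet above: besides `results`, the returned LR-finder object exposes a `suggestion()` helper in the pre-existing tuner API, which this commit does not appear to change. The sketch below assumes that API and reuses the `MyModelClass(hparams)` placeholder from the docs.

from pytorch_lightning import Trainer
from pytorch_lightning.tuner import Tuner

model = MyModelClass(hparams)
trainer = Trainer()
tuner = Tuner(trainer)

lr_finder = tuner.lr_find(model)

# Dict of the learning rates tried and the corresponding losses
print(lr_finder.results)

# Read out the suggested learning rate explicitly and apply it
new_lr = lr_finder.suggestion()
model.hparams.lr = new_lr
trainer.fit(model)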

docs/source-pytorch/common/trainer.rst

Lines changed: 3 additions & 72 deletions
@@ -287,69 +287,6 @@ Example::
     # no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
     trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})

-auto_scale_batch_size
-^^^^^^^^^^^^^^^^^^^^^
-
-.. raw:: html
-
-    <video width="50%" max-width="400px" controls
-    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/auto_scale%E2%80%A8_batch_size.jpg"
-    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/auto_scale_batch_size.mp4"></video>
-
-|
-
-Automatically tries to find the largest batch size that fits into memory,
-before any training.
-
-.. code-block:: python
-
-    # default used by the Trainer (no scaling of batch size)
-    trainer = Trainer(auto_scale_batch_size=None)
-
-    # run batch size scaling, result overrides hparams.batch_size
-    trainer = Trainer(auto_scale_batch_size="binsearch")
-
-    # call tune to find the batch size
-    trainer.tune(model)
-
-
-auto_lr_find
-^^^^^^^^^^^^
-
-.. raw:: html
-
-    <video width="50%" max-width="400px" controls
-    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/auto_lr_find.jpg"
-    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/auto_lr_find.mp4"></video>
-
-|
-
-Runs a learning rate finder algorithm (see this `paper <https://arxiv.org/abs/1506.01186>`_)
-when calling trainer.tune(), to find optimal initial learning rate.
-
-.. code-block:: python
-
-    # default used by the Trainer (no learning rate finder)
-    trainer = Trainer(auto_lr_find=False)
-
-Example::
-
-    # run learning rate finder, results override hparams.learning_rate
-    trainer = Trainer(auto_lr_find=True)
-
-    # call tune to find the lr
-    trainer.tune(model)
-
-Example::
-
-    # run learning rate finder, results override hparams.my_lr_arg
-    trainer = Trainer(auto_lr_find='my_lr_arg')
-
-    # call tune to find the lr
-    trainer.tune(model)
-
-.. note::
-    See the :ref:`learning rate finder guide <learning_rate_finder>`.

 benchmark
 ^^^^^^^^^
@@ -617,7 +554,7 @@ impact to subsequent runs. These are the changes enabled:
 - The :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks will not trigger.
 - The :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callbacks will not trigger.
 - Sets ``limit_{train,val,test,predict}_batches`` to 1 or the number passed.
-- Disables the Tuner.
+- Disables the tuning callbacks (:class:`~pytorch_lightning.callbacks.batch_size_finder.BatchSizeFinder`, :class:`~pytorch_lightning.callbacks.lr_finder.LearningRateFinder`).
 - If using the CLI, the configuration file is not saved.


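For reference, a brief sketch of the callback-based tuning path that the updated bullet refers to; `BatchSizeFinder` and `LearningRateFinder` are the callbacks named in the diff, and per the bullet above they are skipped when `fast_dev_run` is enabled:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BatchSizeFinder, LearningRateFinder

# Regular run: both callbacks tune before training starts.
trainer = Trainer(callbacks=[BatchSizeFinder(mode="power"), LearningRateFinder()])

# Debug run: fast_dev_run disables the same callbacks, keeping the smoke test fast.
debug_trainer = Trainer(fast_dev_run=True, callbacks=[BatchSizeFinder(mode="power"), LearningRateFinder()])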

@@ -1358,12 +1295,6 @@ predict
 .. automethod:: pytorch_lightning.trainer.Trainer.predict
    :noindex:

-tune
-****
-
-.. automethod:: pytorch_lightning.trainer.Trainer.tune
-   :noindex:
-

 Properties
 ^^^^^^^^^^
@@ -1523,11 +1454,11 @@ execution within that function, and the status of the Trainer.

 .. code-block:: python

-    # fn in ("fit", "validate", "test", "predict", "tune")
+    # fn in ("fit", "validate", "test", "predict")
     trainer.state.fn
     # status in ("initializing", "running", "finished", "interrupted")
     trainer.state.status
-    # stage in ("train", "sanity_check", "validate", "test", "predict", "tune")
+    # stage in ("train", "sanity_check", "validate", "test", "predict")
     trainer.state.stage

 should_stop
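A small illustration of the narrowed state values, assuming the string-valued state enums compare equal to the plain strings listed above (a sketch, not part of the diff):

from pytorch_lightning import Callback

class StateLogger(Callback):
    def on_train_start(self, trainer, pl_module):
        # "tune" no longer appears in either set of values.
        assert trainer.state.fn == "fit"
        assert trainer.state.status == "running"
        assert trainer.state.stage == "train"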

src/pytorch_lightning/CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added an argument `include_cuda` in `pytorch_lightning.utilities.seed.isolate_rng` to disable managing `torch.cuda`'s rng ([#16423](https://github.com/Lightning-AI/lightning/pull/16423))

+- Added `Tuner.lr_find(attr_name=...)` to specify custom learning rate attribute names ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
+
 - Added an `OnExceptionCheckpoint` callback to save a checkpoint on exception ([#16512](https://github.com/Lightning-AI/lightning/pull/16512))

 - Added support for running the `MLFlowLogger` with the `mlflow-skinny` package ([16513](https://github.com/Lightning-AI/lightning/pull/16513))
@@ -31,6 +33,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * `pytorch_lightning.plugins.precision.native_amp` is now `pytorch_lightning.plugins.precision.amp`
   * `NativeSyncBatchNorm` is now `TorchSyncBatchNorm`

+- Changed the default of `LearningRateFinder(update_attr=...)` and `Tuner.lr_find(update_attr=...)` to `True` ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
+
 - Renamed the `pl.utilities.exceptions.GracefulExitException` to `SIGTERMException` ([#16501](https://github.com/Lightning-AI/lightning/pull/16501))

 ### Deprecated
@@ -142,9 +146,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Removed support for passing a dictionary value to `self.log()` ([#16389](https://github.com/Lightning-AI/lightning/pull/16389))

+- Removed `Trainer.model` setter ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
+
 - Tuner removal
   * Removed the deprecated `trainer.tuning` property ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))
   * Removed the deprecated `TrainerFn.TUNING` and `RunningStage.TUNING` enums ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))
+  * Removed `Trainer.tune()` in favor of `Tuner(trainer).{lr_find,scale_batch_size}` ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
+  * Removed `Trainer(auto_scale_batch_size=...)` in favor of `Tuner(trainer).scale_batch_size()` ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))
+  * Removed `Trainer(auto_lr_find=...)` in favor of `Tuner(trainer).lr_find()` ([#16462](https://github.com/Lightning-AI/lightning/pull/16462))

 ### Fixed

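Putting the two `lr_find`-related CHANGELOG entries above together, a hedged sketch of how they interact (`LitModel` is a placeholder LightningModule exposing a `my_value` attribute):

from pytorch_lightning import Trainer
from pytorch_lightning.tuner import Tuner

model = LitModel()
tuner = Tuner(Trainer())

# update_attr now defaults to True, so the suggestion is written to model.my_value.
tuner.lr_find(model, attr_name="my_value")

# Opt out of the automatic update and only inspect the result.
lr_finder = tuner.lr_find(model, attr_name="my_value", update_attr=False)
print(lr_finder.suggestion())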
src/pytorch_lightning/callbacks/batch_size_finder.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.callback import Callback
-from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
+from pytorch_lightning.tuner.batch_size_scaling import _scale_batch_size
 from pytorch_lightning.utilities.exceptions import _TunerExitException, MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@@ -165,7 +165,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
             )

     def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        new_size = scale_batch_size(
+        new_size = _scale_batch_size(
             trainer,
             pl_module,
             self._mode,
