
Commit ad1e06f

Update tuner docs (#15087)
1 parent 1865300 commit ad1e06f

File tree

3 files changed (+83 -28 lines changed):

    docs/source-pytorch/advanced/training_tricks.rst
    src/pytorch_lightning/callbacks/batch_size_finder.py
    src/pytorch_lightning/callbacks/lr_finder.py

docs/source-pytorch/advanced/training_tricks.rst

Lines changed: 49 additions & 26 deletions
@@ -191,29 +191,52 @@ The algorithm in short works by:
 Customizing Batch Size Finder
 =============================

-You can also customize the :class:`~pytorch_lightning.callbacks.batch_size_finder.BatchSizeFinder` callback to run at different epochs. This feature is useful while fine-tuning models since
-you can't always use the same batch size after unfreezing the backbone.
+1. You can also customize the :class:`~pytorch_lightning.callbacks.batch_size_finder.BatchSizeFinder` callback to run
+at different epochs. This feature is useful while fine-tuning models since you can't always use the same batch size after
+unfreezing the backbone.

 .. code-block:: python

-    from pytorch_lightning.callbacks import BatchSizeFinder
+    from pytorch_lightning.callbacks import BatchSizeFinder


-    class FineTuneBatchSizeFinder(BatchSizeFinder):
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self.milestones = milestones
+    class FineTuneBatchSizeFinder(BatchSizeFinder):
+        def __init__(self, milestones, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.milestones = milestones

-        def on_fit_start(self, *args, **kwargs):
-            return
+        def on_fit_start(self, *args, **kwargs):
+            return

-        def on_train_epoch_start(self, trainer, pl_module):
-            if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
-                self.scale_batch_size(trainer, pl_module)
+        def on_train_epoch_start(self, trainer, pl_module):
+            if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
+                self.scale_batch_size(trainer, pl_module)


-    trainer = Trainer(callbacks=[FineTuneBatchSizeFinder(milestones=(5, 10))])
-    trainer.fit(...)
+    trainer = Trainer(callbacks=[FineTuneBatchSizeFinder(milestones=(5, 10))])
+    trainer.fit(...)
+
+
+2. Run batch size finder for ``validate``/``test``/``predict``.
+
+.. code-block:: python
+
+    from pytorch_lightning.callbacks import BatchSizeFinder
+
+
+    class EvalBatchSizeFinder(BatchSizeFinder):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        def on_fit_start(self, *args, **kwargs):
+            return
+
+        def on_test_start(self, trainer, pl_module):
+            self.scale_batch_size(trainer, pl_module)
+
+
+    trainer = Trainer(callbacks=[EvalBatchSizeFinder()])
+    trainer.test(...)
@@ -336,24 +359,24 @@ You can also customize the :class:`~pytorch_lightning.callbacks.lr_finder.Learni

 .. code-block:: python

-    from pytorch_lightning.callbacks import LearningRateFinder
+    from pytorch_lightning.callbacks import LearningRateFinder


-    class FineTuneLearningRateFinder(LearningRateFinder):
-        def __init__(self, milestones=(5, 10), *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self.milestones = milestones
+    class FineTuneLearningRateFinder(LearningRateFinder):
+        def __init__(self, milestones, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.milestones = milestones

-        def on_fit_start(self, *args, **kwargs):
-            return
+        def on_fit_start(self, *args, **kwargs):
+            return

-        def on_train_epoch_start(self, trainer, pl_module):
-            if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
-                self.lr_find(trainer, pl_module)
+        def on_train_epoch_start(self, trainer, pl_module):
+            if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
+                self.lr_find(trainer, pl_module)


-    trainer = Trainer(callbacks=[FineTuneLearningRateFinder(milestones=(5, 10))])
-    trainer.fit(...)
+    trainer = Trainer(callbacks=[FineTuneLearningRateFinder(milestones=(5, 10))])
+    trainer.fit(...)


 .. figure:: ../_static/images/trainer/lr_finder.png
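
Note on the BatchSizeFinder snippets above: ``scale_batch_size`` works by rewriting a batch-size attribute on the LightningModule (or its hparams/datamodule) and running a few trial steps, so the module has to expose such an attribute. Below is a minimal sketch of a compatible module; the class name, layer sizes, and dataset are illustrative assumptions, not part of this commit.

    # Hypothetical sketch of a LightningModule that BatchSizeFinder can tune.
    # By default the finder looks for a "batch_size" attribute on the module,
    # its hparams, or the attached datamodule, and overwrites it between trials.
    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset


    class BoringModel(pl.LightningModule):
        def __init__(self, batch_size=2, lr=0.1):
            super().__init__()
            self.batch_size = batch_size  # rewritten by BatchSizeFinder between trials
            self.lr = lr                  # also usable by the learning-rate finder
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=self.lr)

        def train_dataloader(self):
            # Rebuilt for every trial, so the updated self.batch_size takes effect.
            dataset = TensorDataset(torch.randn(640, 32), torch.randint(0, 2, (640,)))
            return DataLoader(dataset, batch_size=self.batch_size)

With such a module, ``Trainer(callbacks=[FineTuneBatchSizeFinder(milestones=(5, 10))])`` from the docs diff re-runs the search at epochs 0, 5, and 10.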

src/pytorch_lightning/callbacks/batch_size_finder.py

Lines changed: 27 additions & 1 deletion
@@ -60,8 +60,14 @@ class BatchSizeFinder(Callback):

     Example::

+        # 1. Customize the BatchSizeFinder callback to run at different epochs. This feature is
+        # useful while fine-tuning models since you can't always use the same batch size after
+        # unfreezing the backbone.
+        from pytorch_lightning.callbacks import BatchSizeFinder
+
+
         class FineTuneBatchSizeFinder(BatchSizeFinder):
-            def __init__(self, *args, **kwargs):
+            def __init__(self, milestones, *args, **kwargs):
                 super().__init__(*args, **kwargs)
                 self.milestones = milestones

@@ -75,6 +81,26 @@ def on_train_epoch_start(self, trainer, pl_module):

         trainer = Trainer(callbacks=[FineTuneBatchSizeFinder(milestones=(5, 10))])
         trainer.fit(...)
+
+    Example::
+
+        # 2. Run batch size finder for validate/test/predict.
+        from pytorch_lightning.callbacks import BatchSizeFinder
+
+
+        class EvalBatchSizeFinder(BatchSizeFinder):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+
+            def on_fit_start(self, *args, **kwargs):
+                return
+
+            def on_test_start(self, trainer, pl_module):
+                self.scale_batch_size(trainer, pl_module)
+
+
+        trainer = Trainer(callbacks=[EvalBatchSizeFinder()])
+        trainer.test(...)
     """

     SUPPORTED_MODES = ("power", "binsearch")
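
Alongside the docstring above, a hedged sketch of constructing the callback directly with explicit arguments: ``SUPPORTED_MODES = ("power", "binsearch")`` comes from this file, while the remaining keyword names (``steps_per_trial``, ``init_val``, ``max_trials``, ``batch_arg_name``) follow the 1.8-era signature and should be checked against the installed release.

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import BatchSizeFinder

    # Sketch: binary-search refinement instead of the default power-of-two scaling.
    finder = BatchSizeFinder(
        mode="binsearch",             # one of SUPPORTED_MODES
        steps_per_trial=3,            # steps executed per candidate batch size
        init_val=2,                   # starting batch size
        max_trials=25,                # cap on how many candidates are tried
        batch_arg_name="batch_size",  # attribute rewritten on the model/datamodule
    )
    trainer = Trainer(callbacks=[finder])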

src/pytorch_lightning/callbacks/lr_finder.py

Lines changed: 7 additions & 1 deletion
@@ -50,8 +50,13 @@ class LearningRateFinder(Callback):

     Example::

+        # Customize LearningRateFinder callback to run at different epochs.
+        # This feature is useful while fine-tuning models.
+        from pytorch_lightning.callbacks import LearningRateFinder
+
+
         class FineTuneLearningRateFinder(LearningRateFinder):
-            def __init__(self, milestones=(5, 10), *args, **kwargs):
+            def __init__(self, milestones, *args, **kwargs):
                 super().__init__(*args, **kwargs)
                 self.milestones = milestones

@@ -62,6 +67,7 @@ def on_train_epoch_start(self, trainer, pl_module):
             if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
                 self.lr_find(trainer, pl_module)

+
         trainer = Trainer(callbacks=[FineTuneLearningRateFinder(milestones=(5, 10))])
         trainer.fit(...)

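
The same tuner is also reachable without a callback through ``trainer.tuner.lr_find`` in 1.8-era releases; the sketch below assumes that API and the hypothetical ``BoringModel`` from the earlier sketch (any module exposing an ``lr`` or ``learning_rate`` attribute works), and produces the curve referenced as ``lr_finder.png`` in the docs diff.

    import pytorch_lightning as pl

    model = BoringModel()                     # hypothetical module from the sketch above
    trainer = pl.Trainer()

    lr_finder = trainer.tuner.lr_find(model)  # run the learning-rate range test
    suggested_lr = lr_finder.suggestion()     # rate picked from the loss-vs-lr curve
    fig = lr_finder.plot(suggest=True)        # the plot shown in lr_finder.png

    model.lr = suggested_lr                   # apply before the real training run
    trainer.fit(model)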
