@@ -43,8 +43,10 @@ class FinetuningScheduler(BaseFinetuning, SchedulingMixin, CallbackDepMixin):
43
43
unfreezing of models via a finetuning schedule that is either implicitly generated (the default) or explicitly
44
44
provided by the user (more computationally efficient).
45
45
46
- Finetuning phase transitions are driven by :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
47
- criteria, user-specified epoch transitions or a composition of the two (the default mode). A
46
+ Finetuning phase transitions are driven by
47
+ :class:`~pytorch_lightning.callbacks.finetuning_scheduler.fts_supporters.FTSEarlyStopping` criteria (a multi-phase
48
+ extension of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`), user-specified epoch transitions
49
+ or a composition of the two (the default mode). A
48
50
:class:`~pytorch_lightning.callbacks.finetuning_scheduler.fts.FinetuningScheduler` training session completes
49
51
when the final phase of the schedule has its stopping criteria met. See
50
52
:ref:`Early Stopping<common/early_stopping:Early stopping>` for more details on that callback's configuration.
@@ -103,11 +105,12 @@ def __init__(
103
105
and exit without training. Typically used to generate a default schedule that will be adjusted by the
104
106
user before training. Defaults to ``False``.
105
107
epoch_transitions_only: If ``True``, Use epoch-driven stopping criteria exclusively (rather than composing
106
- :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` and epoch-driven criteria which is
107
- the default). If using this mode, an epoch-driven transition (``max_transition_epoch`` >= 0) must be
108
- specified for each phase. If unspecified, ``max_transition_epoch`` defaults to -1 for each phase which
109
- signals the application of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` criteria
110
- only. epoch_transitions_only defaults to ``False``.
108
+ :class:`~pytorch_lightning.callbacks.finetuning_scheduler.fts_supporters.FTSEarlyStopping` and
109
+ epoch-driven criteria which is the default). If using this mode, an epoch-driven transition
110
+ (``max_transition_epoch`` >= 0) must be specified for each phase. If unspecified,
111
+ ``max_transition_epoch`` defaults to -1 for each phase which signals the application of
112
+ :class:`~pytorch_lightning.callbacks.finetuning_scheduler.fts_supporters.FTSEarlyStopping` criteria only
113
+ . epoch_transitions_only defaults to ``False``.
111
114
112
115
Attributes:
113
116
_fts_state: The internal finetuning scheduler state.
@@ -268,7 +271,8 @@ def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module:
268
271
"""Before setting up the accelerator environment:
269
272
Dump the default finetuning schedule
270
273
OR
271
- 1. configure the :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callback (if relevant)
274
+ 1. configure the :class:`~pytorch_lightning.callbacks.finetuning_scheduler.fts_supporters.FTSEarlyStopping`
275
+ callback (if relevant)
272
276
2. initialize the :attr:`~pytorch_lightning.callbacks.finetuning_scheduler.fts.FinetuningScheduler._fts_state`
273
277
3. freeze the target :class:`~pytorch_lightning.core.lightning.LightningModule` parameters
274
278
@@ -386,8 +390,8 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
386
390
387
391
def should_transition (self , trainer : "pl.Trainer" ) -> bool :
388
392
"""Phase transition logic is contingent on whether we are composing
389
- :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping ` criteria with epoch-driven transition
390
- constraints or exclusively using epoch-driven transition scheduling. (i.e.,
393
+ :class:`~pytorch_lightning.callbacks.finetuning_scheduler.fts_supporters.FTSEarlyStopping ` criteria with
394
+ epoch-driven transition constraints or exclusively using epoch-driven transition scheduling. (i.e.,
391
395
:attr:`~pytorch_lightning.callbacks.finetuning_scheduler.fts.FinetuningScheduler.epoch_transitions_only` is
392
396
``True``)
393
397
@@ -400,7 +404,7 @@ def should_transition(self, trainer: "pl.Trainer") -> bool:
400
404
if self .depth_remaining > 0
401
405
else self .pl_module .trainer .fit_loop .max_epochs
402
406
)
403
- if not self .epoch_transitions_only : # if we're considering EarlyStopping criteria
407
+ if not self .epoch_transitions_only : # if we're considering FTSEarlyStopping criteria
404
408
epoch_driven_transition = (
405
409
True
406
410
if not self .pl_module .trainer .early_stopping_callback .final_phase
0 commit comments