
Commit 28911d4

refactor FinetuningScheduler: remove the core PL Trainer convenience attribute and the EarlyStopping modifications; add an FTSEarlyStopping extension of EarlyStopping; manage FTS callback dependencies with a CallbackResolverMixin
1 parent ae7d632 commit 28911d4

File tree

11 files changed: +343, -176 lines
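The dependency-management change is the common thread across the diffs below: instead of reading a convenience attribute off the Trainer, FTS components (and the example LightningModule) now look up the callbacks they depend on by type from ``trainer.callbacks``. A minimal sketch of that resolve-by-type idea, using only the public Trainer API; the actual CallbackResolverMixin lives in fts_supporters.py and is not shown in this commit's visible diffs:

from typing import Optional, Type, TypeVar

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

T = TypeVar("T", bound=Callback)


def resolve_callback(trainer: "pl.Trainer", callback_type: Type[T]) -> Optional[T]:
    """Illustrative helper: return the first Trainer-registered callback of the given type."""
    matches = [c for c in trainer.callbacks if isinstance(c, callback_type)]
    return matches[0] if matches else None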

pl_examples/basic_examples/config/fts/fts_explicit.yaml

Lines changed: 2 additions & 2 deletions
@@ -1,14 +1,14 @@
 trainer:
   callbacks:
-  - class_path: pytorch_lightning.callbacks.FinetuningScheduler
+  - class_path: pytorch_lightning.callbacks.finetuning_scheduler.FinetuningScheduler
     init_args:
       ft_schedule: ./pl_examples/basic_examples/config/fts/RteBoolqModule_ft_schedule_albert_base.yaml
   - class_path: pytorch_lightning.callbacks.finetuning_scheduler.FTSCheckpoint
     init_args:
       save_top_k: 5
       monitor: val_loss
       verbose: true
-  - class_path: pytorch_lightning.callbacks.EarlyStopping
+  - class_path: pytorch_lightning.callbacks.finetuning_scheduler.FTSEarlyStopping
     init_args:
       monitor: val_loss
       min_delta: 0.001
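For reference, a rough programmatic equivalent of the explicit configuration above, assuming the constructor keywords mirror the ``init_args`` shown (a sketch, not taken from the commit):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.finetuning_scheduler import (
    FinetuningScheduler,
    FTSCheckpoint,
    FTSEarlyStopping,
)

trainer = Trainer(
    callbacks=[
        # explicit schedule supplied by the user
        FinetuningScheduler(
            ft_schedule="./pl_examples/basic_examples/config/fts/RteBoolqModule_ft_schedule_albert_base.yaml"
        ),
        FTSCheckpoint(save_top_k=5, monitor="val_loss", verbose=True),
        FTSEarlyStopping(monitor="val_loss", min_delta=0.001),
    ]
)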

pl_examples/basic_examples/config/fts/fts_implicit.yaml

Lines changed: 2 additions & 2 deletions
@@ -1,12 +1,12 @@
 trainer:
   callbacks:
-  - class_path: pytorch_lightning.callbacks.FinetuningScheduler
+  - class_path: pytorch_lightning.callbacks.finetuning_scheduler.FinetuningScheduler
   - class_path: pytorch_lightning.callbacks.finetuning_scheduler.FTSCheckpoint
     init_args:
       save_top_k: 5
       monitor: val_loss
       verbose: true
-  - class_path: pytorch_lightning.callbacks.EarlyStopping
+  - class_path: pytorch_lightning.callbacks.finetuning_scheduler.FTSEarlyStopping
     init_args:
       monitor: val_loss
       min_delta: 0.001
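The implicit variant differs only in omitting ``ft_schedule``; per the FinetuningScheduler docstring further down in this commit, the schedule is then implicitly generated. Programmatically (sketch only):

from pytorch_lightning.callbacks.finetuning_scheduler import FinetuningScheduler

# no ft_schedule supplied -> a default schedule is implicitly generated
fts = FinetuningScheduler()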

pl_examples/basic_examples/config/fts/nofts_baseline.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 trainer:
   callbacks:
-  - class_path: pytorch_lightning.callbacks.EarlyStopping
+  - class_path: pytorch_lightning.callbacks.early_stopping.EarlyStopping
     init_args:
       monitor: val_loss
       min_delta: 0.001

pl_examples/basic_examples/fts_superglue.py

Lines changed: 10 additions & 2 deletions
@@ -43,6 +43,7 @@
 
 import pytorch_lightning as pl
 from pl_examples import _HF_AVAILABLE
+from pytorch_lightning.callbacks.finetuning_scheduler.fts import FinetuningScheduler
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cli import instantiate_class, LightningCLI
 
@@ -116,6 +117,7 @@ def __init__(
         self.model.config.update(self.model_cfg)  # apply model config overrides
         self.metric = datasets.load_metric("super_glue", self.task_name, experiment_id=self.experiment_id)
         self.no_decay = ["bias", "LayerNorm.weight"]
+        self.finetuningscheduler_callback = None
 
     def forward(self, **inputs):
         return self.model(**inputs)
@@ -128,8 +130,8 @@ def training_step(self, batch, batch_idx):
     def training_epoch_end(self, outputs: List[Any]) -> None:
         loss = torch.stack([x["loss"] for x in outputs]).mean()
         self.log("train_loss", loss, prog_bar=True, sync_dist=True)
-        if self.trainer.finetuning_scheduler_callback:
-            self.log("finetuning_schedule_depth", self.trainer.finetuning_scheduler_callback.curr_depth)
+        if self.finetuningscheduler_callback:
+            self.log("finetuning_schedule_depth", self.finetuningscheduler_callback.curr_depth)
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(**batch)
@@ -183,6 +185,12 @@ def configure_optimizers(self):
         scheduler = {"scheduler": instantiate_class(optimizer, self.lr_scheduler_init), **self.pl_lrs_cfg}
         return [optimizer], [scheduler]
 
+    def configure_callbacks(self):
+        found_fts = [c for c in self.trainer.callbacks if isinstance(c, FinetuningScheduler)]
+        if found_fts:
+            self.finetuningscheduler_callback = found_fts[0]
+        return super().configure_callbacks()
+
 
 class RteBoolqDataModule(pl.LightningDataModule):
     """A :class:`~pytorch_lighting.core.LightningDataModule` for using either the RTE or BoolQ `SuperGLUE Hugging

pytorch_lightning/callbacks/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -15,7 +15,6 @@
 from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
-from pytorch_lightning.callbacks.finetuning_scheduler import FinetuningScheduler  # idso
 from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
 from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
 from pytorch_lightning.callbacks.lambda_function import LambdaCallback
@@ -32,11 +31,10 @@
 from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
 
 __all__ = [
-    "FinetuningScheduler",
     "BackboneFinetuning",
     "BaseFinetuning",
-    "DeviceStatsMonitor",
     "Callback",
+    "DeviceStatsMonitor",
    "EarlyStopping",
    "GPUStatsMonitor",
    "XLAStatsMonitor",

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 6 additions & 11 deletions
@@ -112,8 +112,6 @@ def __init__(
         self.divergence_threshold = divergence_threshold
         self.wait_count = 0
         self.stopped_epoch = 0
-        self.es_phase_complete = True
-        self.final_phase = True
         self._check_on_train_epoch_end = check_on_train_epoch_end
 
         if self.mode not in self.mode_dict:
@@ -240,15 +238,12 @@ def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, Opti
         else:
             self.wait_count += 1
             if self.wait_count >= self.patience:
-                if self.final_phase:
-                    should_stop = True
-                    reason = (
-                        f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
-                        f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
-                    )
-                else:
-                    self.es_phase_complete = True
-                    self.wait_count = 0
+                should_stop = True
+                reason = (
+                    f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
+                    f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
+                )
+
         return should_stop, reason
 
     def _improvement_message(self, current: torch.Tensor) -> str:
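The multi-phase state removed here (``final_phase``, ``es_phase_complete``, and the non-final-phase reset of ``wait_count``) is what the new FTSEarlyStopping carries instead. Its actual implementation lives in fts_supporters.py and is not part of the visible diffs, so the following is only a sketch of the extension point, reusing the logic deleted above:

from typing import Optional, Tuple

import torch

from pytorch_lightning.callbacks.early_stopping import EarlyStopping


class SketchFTSEarlyStopping(EarlyStopping):
    """Illustrative subclass restoring the multi-phase behavior removed from core EarlyStopping."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.es_phase_complete = True
        self.final_phase = True

    def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, Optional[str]]:
        should_stop, reason = super()._evaluate_stopping_criteria(current)
        if should_stop and not self.final_phase:
            # in a non-final phase, mark the phase complete instead of stopping the Trainer
            should_stop, reason = False, None
            self.es_phase_complete = True
            self.wait_count = 0
        return should_stop, reason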

pytorch_lightning/callbacks/finetuning_scheduler/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -21,7 +21,9 @@
 from pytorch_lightning.callbacks.finetuning_scheduler.fts_supporters import (  # noqa: F401 # isort: skip
     FTSState,
     FTSCheckpoint,
+    FTSEarlyStopping,
     SchedulingMixin,
     CallbackDepMixin,
+    CallbackResolverMixin,
 )
 from pytorch_lightning.callbacks.finetuning_scheduler.fts import FinetuningScheduler  # noqa: F401

pytorch_lightning/callbacks/finetuning_scheduler/fts.py

Lines changed: 15 additions & 11 deletions
@@ -43,8 +43,10 @@ class FinetuningScheduler(BaseFinetuning, SchedulingMixin, CallbackDepMixin):
     unfreezing of models via a finetuning schedule that is either implicitly generated (the default) or explicitly
     provided by the user (more computationally efficient).
 
-    Finetuning phase transitions are driven by :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
-    criteria, user-specified epoch transitions or a composition of the two (the default mode). A
+    Finetuning phase transitions are driven by
+    :class:`~pytorch_lightning.callbacks.finetuning_scheduler.fts_supporters.FTSEarlyStopping` criteria (a multi-phase
+    extension of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`), user-specified epoch transitions
+    or a composition of the two (the default mode). A
     :class:`~pytorch_lightning.callbacks.finetuning_scheduler.fts.FinetuningScheduler` training session completes
     when the final phase of the schedule has its stopping criteria met. See
     :ref:`Early Stopping<common/early_stopping:Early stopping>` for more details on that callback's configuration.
@@ -103,11 +105,12 @@ def __init__(
                 and exit without training. Typically used to generate a default schedule that will be adjusted by the
                 user before training. Defaults to ``False``.
             epoch_transitions_only: If ``True``, Use epoch-driven stopping criteria exclusively (rather than composing
-                :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` and epoch-driven criteria which is
-                the default). If using this mode, an epoch-driven transition (``max_transition_epoch`` >= 0) must be
-                specified for each phase. If unspecified, ``max_transition_epoch`` defaults to -1 for each phase which
-                signals the application of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` criteria
-                only. epoch_transitions_only defaults to ``False``.
+                :class:`~pytorch_lightning.callbacks.finetuning_scheduler.fts_supporters.FTSEarlyStopping` and
+                epoch-driven criteria which is the default). If using this mode, an epoch-driven transition
+                (``max_transition_epoch`` >= 0) must be specified for each phase. If unspecified,
+                ``max_transition_epoch`` defaults to -1 for each phase which signals the application of
+                :class:`~pytorch_lightning.callbacks.finetuning_scheduler.fts_supporters.FTSEarlyStopping` criteria only
+                . epoch_transitions_only defaults to ``False``.
 
         Attributes:
             _fts_state: The internal finetuning scheduler state.
@@ -268,7 +271,8 @@ def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module:
         """Before setting up the accelerator environment:
         Dump the default finetuning schedule
         OR
-        1. configure the :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callback (if relevant)
+        1. configure the :class:`~pytorch_lightning.callbacks.finetuning_scheduler.fts_supporters.FTSEarlyStopping`
+           callback (if relevant)
         2. initialize the :attr:`~pytorch_lightning.callbacks.finetuning_scheduler.fts.FinetuningScheduler._fts_state`
         3. freeze the target :class:`~pytorch_lightning.core.lightning.LightningModule` parameters
@@ -386,8 +390,8 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
 
     def should_transition(self, trainer: "pl.Trainer") -> bool:
         """Phase transition logic is contingent on whether we are composing
-        :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` criteria with epoch-driven transition
-        constraints or exclusively using epoch-driven transition scheduling. (i.e.,
+        :class:`~pytorch_lightning.callbacks.finetuning_scheduler.fts_supporters.FTSEarlyStopping` criteria with
+        epoch-driven transition constraints or exclusively using epoch-driven transition scheduling. (i.e.,
         :attr:`~pytorch_lightning.callbacks.finetuning_scheduler.fts.FinetuningScheduler.epoch_transitions_only` is
         ``True``)
@@ -400,7 +404,7 @@ def should_transition(self, trainer: "pl.Trainer") -> bool:
             if self.depth_remaining > 0
             else self.pl_module.trainer.fit_loop.max_epochs
         )
-        if not self.epoch_transitions_only:  # if we're considering EarlyStopping criteria
+        if not self.epoch_transitions_only:  # if we're considering FTSEarlyStopping criteria
             epoch_driven_transition = (
                 True
                 if not self.pl_module.trainer.early_stopping_callback.final_phase
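The fragment above reaches the new callback through the Trainer's existing ``early_stopping_callback`` property, which still works because FTSEarlyStopping subclasses EarlyStopping. A defensive lookup might look like this (illustrative, not part of the commit):

from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.finetuning_scheduler import FTSEarlyStopping


def fts_early_stopping(trainer: "pl.Trainer") -> Optional[FTSEarlyStopping]:
    """Return the Trainer's early stopping callback only if it is the FTS-aware variant."""
    es = trainer.early_stopping_callback
    return es if isinstance(es, FTSEarlyStopping) else None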
