
Commit 536c132

checkpoint consolidation

Author: Shuying Sun
1 parent 80cfbff commit 536c132

File tree: 10 files changed, +99 −39 lines

pytorch_lightning/callbacks/base.py

Lines changed: 4 additions & 0 deletions
@@ -109,6 +109,10 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
         """Called when the epoch ends."""
         pass
 
+    def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None:
+        """Called at the very end of the train epoch."""
+        pass
+
     def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
         """Called when the training batch begins."""
         pass
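For reference, a user-side callback opts into the new hook by overriding it. A minimal sketch, assuming the hook signature above; MyCallback is an illustrative name, not part of this commit:

from pytorch_lightning.callbacks import Callback

class MyCallback(Callback):
    def on_train_epoch_final_end(self, trainer, pl_module):
        # invoked by the trainer at the very end of the train epoch
        print(f"finished epoch {trainer.current_epoch}")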

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 15 additions & 0 deletions
@@ -143,6 +143,21 @@ def on_validation_end(self, trainer, pl_module):
 
         self._run_early_stopping_check(trainer)
 
+    def on_train_epoch_final_end(self, trainer, pl_module):
+        from pytorch_lightning.trainer.states import TrainerState
+        if (
+            trainer.state != TrainerState.FITTING or trainer.sanity_checking
+            or not trainer.checkpoint_connector.has_trained
+        ):
+            return
+        # if validation is disabled or should be skipped, run the early
+        # stopping check at the end of the training epoch
+        if (
+            trainer.disable_validation
+            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
+        ):
+            self._run_early_stopping_check(trainer)
+
     def _run_early_stopping_check(self, trainer):
         """
         Checks whether the early stopping condition is met
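To illustrate when this path fires: with validation disabled, on_validation_end never runs, so early stopping has to key off a metric logged during training. A sketch under that assumption (the train_loss monitor key is illustrative, not from this commit):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

trainer = Trainer(
    limit_val_batches=0,  # disables the validation loop
    callbacks=[EarlyStopping(monitor='train_loss')],
)
# the early-stopping check now runs in on_train_epoch_final_end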

pytorch_lightning/callbacks/lambda_function.py

Lines changed: 3 additions & 0 deletions
@@ -53,6 +53,7 @@ def __init__(
         on_train_batch_end: Optional[Callable] = None,
         on_train_epoch_start: Optional[Callable] = None,
         on_train_epoch_end: Optional[Callable] = None,
+        on_train_epoch_final_end: Optional[Callable] = None,
         on_validation_epoch_start: Optional[Callable] = None,
         on_validation_epoch_end: Optional[Callable] = None,
         on_test_epoch_start: Optional[Callable] = None,
@@ -155,3 +156,5 @@ def __init__(
         self.on_after_backward = on_after_backward
         if on_before_zero_grad is not None:
             self.on_before_zero_grad = on_before_zero_grad
+        if on_train_epoch_final_end is not None:
+            self.on_train_epoch_final_end = on_train_epoch_final_end
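With the wiring above, the new argument can be exercised in one line; a minimal sketch:

from pytorch_lightning.callbacks import LambdaCallback

# the lambda receives the same (trainer, pl_module) arguments as the hook
cb = LambdaCallback(
    on_train_epoch_final_end=lambda trainer, pl_module: print('epoch done')
)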

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 31 additions & 0 deletions
@@ -238,6 +238,37 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None:
             return
         self.save_checkpoint(trainer)
 
+    def on_train_epoch_final_end(self, trainer, pl_module):
+        """
+        Checkpoint at the end of each training epoch, but only when
+        validation is skipped or disabled.
+        """
+        if (
+            self._should_skip_saving_checkpoint(trainer)
+            or not trainer.checkpoint_connector.has_trained
+        ):
+            return
+        # if validation is disabled or should be skipped, checkpoint
+        # at the end of the training epoch
+        if (
+            trainer.disable_validation
+            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
+        ):
+            self.save_checkpoint(trainer)
+
+    def on_train_end(self, trainer, *args, **kwargs) -> None:
+        """
+        Save a checkpoint at the end of training.
+        """
+        # temporarily decrease the global step to avoid saving a duplicate
+        # when a checkpoint was already saved at the last step
+        trainer.global_step -= 1
+        if (
+            not self._should_skip_saving_checkpoint(trainer)
+            and trainer.checkpoint_connector.has_trained
+        ):
+            if self.save_last and self.verbose:
+                rank_zero_info("Saving latest checkpoint...")
+            self.save_checkpoint(trainer)
+        trainer.global_step += 1
+
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,

pytorch_lightning/trainer/callback_hook.py

Lines changed: 7 additions & 0 deletions
@@ -92,6 +92,13 @@ def on_train_epoch_end(self, outputs: List[Any]):
         for callback in self.callbacks:
             callback.on_train_epoch_end(self, self.lightning_module, outputs)
 
+    def on_train_epoch_final_end(self) -> None:
+        """
+        Called at the very end of the train epoch.
+        """
+        for callback in self.callbacks:
+            callback.on_train_epoch_final_end(self, self.lightning_module)
+
     def on_validation_epoch_start(self):
         """Called when the epoch begins."""
         for callback in self.callbacks:

pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py

Lines changed: 5 additions & 0 deletions
@@ -100,6 +100,11 @@ def _on_train_epoch_end_log():
         """Called when the epoch ends."""
         return {"on_step": [False], "on_epoch": [False, True]}
 
+    @staticmethod
+    def _on_train_epoch_final_end_log():
+        """Called at the very end of the train epoch."""
+        return {"on_step": [False], "on_epoch": [False, True]}
+
     @staticmethod
     def _on_validation_epoch_start_log():
         """Called when the epoch begins."""

pytorch_lightning/trainer/training_loop.py

Lines changed: 3 additions & 32 deletions
@@ -121,12 +121,6 @@ def on_train_end(self):
             return
         self._teardown_already_run = True
 
-        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
-        # when a checkpoint was saved at the last step
-        self.trainer.global_step -= 1
-        self.check_checkpoint_callback(should_update=True, is_last=True)
-        self.trainer.global_step += 1
-
         # hook
         self.trainer.call_hook("on_train_end")
 
@@ -145,28 +139,6 @@ def on_train_end(self):
         # reset bookkeeping
         self.trainer._running_stage = None
 
-    def check_checkpoint_callback(self, should_update, is_last=False):
-        # TODO bake this logic into the ModelCheckpoint callback
-        if should_update and self.trainer.checkpoint_connector.has_trained:
-            callbacks = self.trainer.checkpoint_callbacks
-
-            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
-                rank_zero_info("Saving latest checkpoint...")
-
-            model = self.trainer.lightning_module
-
-            for cb in callbacks:
-                cb.on_validation_end(self.trainer, model)
-
-    def check_early_stopping_callback(self, should_update):
-        # TODO bake this logic into the EarlyStopping callback
-        if should_update and self.trainer.checkpoint_connector.has_trained:
-            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
-            model = self.trainer.lightning_module
-
-            for cb in callbacks:
-                cb.on_validation_end(self.trainer, model)
-
     def on_train_epoch_start(self, epoch):
 
         # update training progress in trainer
 
@@ -562,15 +534,14 @@ def run_training_epoch(self):
         if (val_loop_called and not should_check_val) or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
 
-        if should_train_only:
-            self.check_checkpoint_callback(True)
-            self.check_early_stopping_callback(True)
-
         if should_check_val:
             self.trainer.validating = True
             self.trainer.run_evaluation(on_epoch=True)
             self.trainer.training = True
 
+        if should_train_only:
+            self.trainer.call_hook('on_train_epoch_final_end')
+
         # increment the global step once
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()

tests/checkpointing/test_model_checkpoint.py

Lines changed: 29 additions & 6 deletions
@@ -609,7 +609,13 @@ def test_model_checkpoint_period(tmpdir, period: int):
     trainer.fit(model)
 
     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs]
+        if period > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)
 
@@ -631,8 +637,14 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
     trainer.fit(model)
 
     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)
-                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
+        if every_n_val_epochs > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)
 
@@ -659,8 +671,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc
     trainer.fit(model)
 
     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)
-                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
+        if every_n_val_epochs > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)
 
@@ -816,10 +834,15 @@ def test_model_checkpoint_save_last_warning(
         default_root_dir=tmpdir,
         callbacks=[ckpt],
         max_epochs=max_epochs,
+        val_check_interval=0.1,
     )
     with caplog.at_level(logging.INFO):
         trainer.fit(model)
-    assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
+    if verbose and save_last and not should_validate:
+        # without validation the checkpoint is triggered at the end of
+        # each training epoch, so the message is not logged here
+        assert caplog.messages.count('Saving latest checkpoint...') == 0
+    else:
+        assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
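As a concrete check of the updated expectation in the first hunk above, take epochs = 5 and period = 2 (illustrative values):

# periodic ckpts: e in range(5) with (e + 1) % 2 == 0 and e + 1 != 5  ->  e = 1, 3
# plus the unconditional final-epoch ckpt for e = 4
expected = ['epoch=1.ckpt', 'epoch=3.ckpt', 'epoch=4.ckpt']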

tests/helpers/utils.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def reset_seed(seed=0):
 def set_random_master_port():
     reset_seed()
     port = RANDOM_PORTS.pop()
-    os.environ['MASTER_PORT'] = str(port)
+    os.environ['MASTER_PORT'] = "29501"

tests/trainer/logging_/test_logger_connector.py

Lines changed: 1 addition & 0 deletions
@@ -300,6 +300,7 @@ def test_call_back_validator(tmpdir):
         'on_train_batch_start',
         'on_train_end',
         'on_train_epoch_end',
+        'on_train_epoch_final_end',
         'on_train_epoch_start',
         'on_train_start',
         'on_validation_batch_end',
