
Commit 3a9fde9

Revert "checkpoint consolidation"
This reverts commit 536c132.
1 parent 8222dc9 commit 3a9fde9

10 files changed: +39 −99 lines changed

pytorch_lightning/callbacks/base.py

Lines changed: 0 additions & 4 deletions
@@ -109,10 +109,6 @@ def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
         """Called when the epoch ends."""
         pass

-    def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None:
-        """Called when at the very end of train epoch."""
-        pass
-
     def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
         """Called when the training batch begins."""
         pass
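
This hunk removes the on_train_epoch_final_end hook from the base Callback class. For reference, a minimal sketch (not part of this commit; the callback name is hypothetical) of how end-of-epoch logic is expressed with the hooks that remain, using the on_train_epoch_end signature that callback_hook.py below still dispatches:

from pytorch_lightning.callbacks import Callback

class EpochEndPrinter(Callback):
    def on_train_epoch_end(self, trainer, pl_module, outputs):
        # runs at the end of every training epoch, with or without validation
        print(f"epoch {trainer.current_epoch} done at global step {trainer.global_step}")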

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 0 additions & 15 deletions
@@ -143,21 +143,6 @@ def on_validation_end(self, trainer, pl_module):

         self._run_early_stopping_check(trainer)

-    def on_train_epoch_final_end(self, trainer, pl_module):
-        from pytorch_lightning.trainer.states import TrainerState
-        if (
-            trainer.state != TrainerState.FITTING or trainer.sanity_checking
-            or not trainer.checkpoint_connector.has_trained
-        ):
-            return
-        # if validation is disabled or should skip, we run early stopping
-        # at end of the training epoch
-        if (
-            trainer.disable_validation
-            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
-        ):
-            self._run_early_stopping_check(trainer)
-
     def _run_early_stopping_check(self, trainer):
         """
         Checks whether the early stopping condition is met
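
With this method gone, EarlyStopping no longer decides on its own whether to run when validation is skipped; that decision moves back into the training loop (see training_loop.py below), which invokes the callback's on_validation_end directly. A minimal usage sketch of the train-only path the revert restores (assumption: the model logs a 'train_loss' metric for the monitor):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

trainer = Trainer(
    limit_val_batches=0,  # validation disabled -> every epoch is train-only
    callbacks=[EarlyStopping(monitor='train_loss')],
)
# check_early_stopping_callback(True) now fires this callback at each epoch end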

pytorch_lightning/callbacks/lambda_function.py

Lines changed: 0 additions & 3 deletions
@@ -53,7 +53,6 @@ def __init__(
         on_train_batch_end: Optional[Callable] = None,
         on_train_epoch_start: Optional[Callable] = None,
         on_train_epoch_end: Optional[Callable] = None,
-        on_train_epoch_final_end: Optional[Callable] = None,
         on_validation_epoch_start: Optional[Callable] = None,
         on_validation_epoch_end: Optional[Callable] = None,
         on_test_epoch_start: Optional[Callable] = None,
@@ -156,5 +155,3 @@ def __init__(
         self.on_after_backward = on_after_backward
         if on_before_zero_grad is not None:
             self.on_before_zero_grad = on_before_zero_grad
-        if on_train_epoch_final_end is not None:
-            self.on_train_epoch_final_end = on_train_epoch_final_end
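
The matching LambdaCallback parameter is dropped along with the hook. For context, a minimal sketch of the pattern this class implements, where each constructor argument, when provided, is bound as the callback method of the same name (the lambda body is a hypothetical example):

from pytorch_lightning.callbacks import LambdaCallback

cb = LambdaCallback(
    on_train_epoch_end=lambda trainer, pl_module, outputs: print("epoch finished"),
)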

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 0 additions & 31 deletions
@@ -238,37 +238,6 @@ def on_validation_end(self, trainer, *args, **kwargs) -> None:
             return
         self.save_checkpoint(trainer)

-    def on_train_epoch_final_end(self, trainer, pl_module):
-        """
-        at the end of each training epoch, checkpoint only when validation is skipped or disabled
-        """
-        print("aaa: epoch {}, step: {}".format(trainer.current_epoch, trainer.global_step))
-        if (
-            self._should_skip_saving_checkpoint(trainer)
-            or not trainer.checkpoint_connector.has_trained
-        ):
-            return
-        # if validation is disabled or should skip, we checkpoint at end of the training epoch
-        if (
-            trainer.disable_validation
-            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
-        ):
-            self.save_checkpoint(trainer)
-
-    def on_train_end(self, trainer, *args, **kwargs) -> None:
-        """
-        checkpoints can be saved at the end of the trianing
-        """
-        trainer.global_step -= 1
-        if (
-            not self._should_skip_saving_checkpoint(trainer)
-            and trainer.checkpoint_connector.has_trained
-        ):
-            if self.save_last and self.verbose:
-                rank_zero_info("Saving latest checkpoint...")
-            self.save_checkpoint(trainer)
-        trainer.global_step += 1
-
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,

pytorch_lightning/trainer/callback_hook.py

Lines changed: 0 additions & 7 deletions
@@ -92,13 +92,6 @@ def on_train_epoch_end(self, outputs: List[Any]):
         for callback in self.callbacks:
             callback.on_train_epoch_end(self, self.lightning_module, outputs)

-    def on_train_epoch_final_end(self) -> None:
-        """
-        Called when at the very end of train epoch.
-        """
-        for callback in self.callbacks:
-            callback.on_train_epoch_final_end(self, self.lightning_module)
-
     def on_validation_epoch_start(self):
         """Called when the epoch begins."""
         for callback in self.callbacks:

pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py

Lines changed: 0 additions & 5 deletions
@@ -100,11 +100,6 @@ def _on_train_epoch_end_log():
         """Called when the epoch ends."""
         return {"on_step": [False], "on_epoch": [False, True]}

-    @staticmethod
-    def _on_train_epoch_final_end_log():
-        """Called when at the very end of train epoch."""
-        return {"on_step": [False], "on_epoch": [False, True]}
-
     @staticmethod
     def _on_validation_epoch_start_log():
         """Called when the epoch begins."""

pytorch_lightning/trainer/training_loop.py

Lines changed: 32 additions & 3 deletions
@@ -121,6 +121,12 @@ def on_train_end(self):
             return
         self._teardown_already_run = True

+        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
+        # when a checkpoint was saved at the last step
+        self.trainer.global_step -= 1
+        self.check_checkpoint_callback(should_update=True, is_last=True)
+        self.trainer.global_step += 1
+
         # hook
         self.trainer.call_hook("on_train_end")

@@ -139,6 +145,28 @@ def on_train_end(self):
         # reset bookkeeping
         self.trainer._running_stage = None

+    def check_checkpoint_callback(self, should_update, is_last=False):
+        # TODO bake this logic into the ModelCheckpoint callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = self.trainer.checkpoint_callbacks
+
+            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
+                rank_zero_info("Saving latest checkpoint...")
+
+            model = self.trainer.lightning_module
+
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
+
+    def check_early_stopping_callback(self, should_update):
+        # TODO bake this logic into the EarlyStopping callback
+        if should_update and self.trainer.checkpoint_connector.has_trained:
+            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
+            model = self.trainer.lightning_module
+
+            for cb in callbacks:
+                cb.on_validation_end(self.trainer, model)
+
     def on_train_epoch_start(self, epoch):

         # update training progress in trainer
@@ -534,14 +562,15 @@ def run_training_epoch(self):
         if (val_loop_called and not should_check_val) or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

+        if should_train_only:
+            self.check_checkpoint_callback(True)
+            self.check_early_stopping_callback(True)
+
         if should_check_val:
             self.trainer.validating = True
             self.trainer.run_evaluation(on_epoch=True)
             self.trainer.training = True

-        if should_train_only:
-            self.trainer.call_hook('on_train_epoch_final_end')
-
         # increment the global step once
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()
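
Taken together, the restored flow is: when an epoch is train-only (validation disabled or skipped), run_training_epoch calls check_checkpoint_callback(True) and check_early_stopping_callback(True), both of which reuse the callbacks' on_validation_end entry points; and at the very end of training, global_step is temporarily decremented so a checkpoint already saved at the last step is not written twice. A minimal end-to-end sketch of that train-only setup (hypothetical model; it must log 'train_loss' for EarlyStopping to monitor):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

trainer = Trainer(
    max_epochs=3,
    limit_val_batches=0,  # no validation batches -> should_train_only is True
    callbacks=[
        ModelCheckpoint(dirpath='ckpts/', save_last=True, verbose=True),
        EarlyStopping(monitor='train_loss'),
    ],
)
# trainer.fit(model)  # checkpointing and early stopping run at each epoch end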

tests/checkpointing/test_model_checkpoint.py

Lines changed: 6 additions & 29 deletions
@@ -609,13 +609,7 @@ def test_model_checkpoint_period(tmpdir, period: int):
     trainer.fit(model)

     # check that the correct ckpts were created
-    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
-    expected = (
-        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs]
-        if period > 0
-        else []
-    )
-    expected.append(final_epoch_ckpt)
+    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)

@@ -637,14 +631,8 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
     trainer.fit(model)

     # check that the correct ckpts were created
-    # check that the correct ckpts were created
-    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
-    expected = (
-        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
-        if every_n_val_epochs > 0
-        else []
-    )
-    expected.append(final_epoch_ckpt)
+    expected = [f'epoch={e}.ckpt' for e in range(epochs)
+                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)

@@ -671,14 +659,8 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc
     trainer.fit(model)

     # check that the correct ckpts were created
-    # check that the correct ckpts were created
-    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
-    expected = (
-        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
-        if every_n_val_epochs > 0
-        else []
-    )
-    expected.append(final_epoch_ckpt)
+    expected = [f'epoch={e}.ckpt' for e in range(epochs)
+                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)

@@ -834,15 +816,10 @@ def test_model_checkpoint_save_last_warning(
         default_root_dir=tmpdir,
         callbacks=[ckpt],
         max_epochs=max_epochs,
-        val_check_interval=0.1,
     )
     with caplog.at_level(logging.INFO):
         trainer.fit(model)
-    if verbose and save_last and not should_validate:
-        # no validation, hence checkpoint triggered at the end of each training epoch
-        assert caplog.messages.count('Saving latest checkpoint...') == False
-    else:
-        assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
+    assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)


 def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
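
The reverted tests special-cased the final epoch and appended an extra final-epoch checkpoint to the expectation; the restored expectation keeps only period-aligned epochs. A worked example of the restored expression with hypothetical values epochs=5, period=2:

epochs, period = 5, 2
expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
assert expected == ['epoch=1.ckpt', 'epoch=3.ckpt']  # epochs 2 and 4, zero-based names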

tests/helpers/utils.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def reset_seed(seed=0):
 def set_random_master_port():
     reset_seed()
     port = RANDOM_PORTS.pop()
-    os.environ['MASTER_PORT'] = "29501"
+    os.environ['MASTER_PORT'] = str(port)


 def init_checkpoint_callback(logger):
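
Restoring str(port) makes each test process export the port it actually drew from RANDOM_PORTS instead of a hardcoded "29501", so concurrent DDP test runs do not collide on one port. A minimal sketch of the same idea in isolation (function name is hypothetical; the ports argument stands in for the module's port pool):

import os

def export_master_port(ports):
    port = ports.pop()
    os.environ['MASTER_PORT'] = str(port)  # unique per process, not hardcoded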

tests/trainer/logging_/test_logger_connector.py

Lines changed: 0 additions & 1 deletion
@@ -300,7 +300,6 @@ def test_call_back_validator(tmpdir):
         'on_train_batch_start',
         'on_train_end',
         'on_train_epoch_end',
-        'on_train_epoch_final_end',
         'on_train_epoch_start',
         'on_train_start',
         'on_validation_batch_end',
