Skip to content

Commit 0ab6d8b

Browse files
committed
update again
1 parent 69a2d5a commit 0ab6d8b

File tree

3 files changed

+18
-37
lines changed

3 files changed

+18
-37
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 0 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -188,7 +188,6 @@ def __init__(
188188
auto_insert_metric_name: bool = True,
189189
every_n_train_steps: Optional[int] = None,
190190
every_n_val_epochs: Optional[int] = None,
191-
save_on_train_end: bool = False,
192191
period: Optional[int] = None,
193192
):
194193
super().__init__()
@@ -205,7 +204,6 @@ def __init__(
205204
self.best_model_score = None
206205
self.best_model_path = ""
207206
self.last_model_path = ""
208-
self._save_on_train_end = save_on_train_end
209207

210208
self.__init_monitor_mode(monitor, mode)
211209
self.__init_ckpt_dir(dirpath, filename, save_top_k)
@@ -242,33 +240,6 @@ def on_validation_end(self, trainer, pl_module) -> None:
242240
return
243241
self.save_checkpoint(trainer)
244242

245-
def on_train_end(self, trainer, pl_module) -> None:
246-
"""Save a checkpoint at the very end of training.
247-
248-
This will only save a checkpoint if `save_last` is also enabled
249-
as the monitor metrics produced by training or validation steps or end of epochs
250-
is not guaranteed to be available at this stage.
251-
"""
252-
if self._should_skip_saving_checkpoint(trainer) or not trainer.checkpoint_connector.has_trained:
253-
return
254-
255-
initial_save_last = self.save_last
256-
if self._save_on_train_end and not self.save_last:
257-
rank_zero_warn(
258-
"Requested to save a checkpoint at the end of training but save_last is not set. Temporarily setting save_last=True to save."
259-
)
260-
self.save_last = True
261-
if self.verbose:
262-
rank_zero_info("Saving last checkpoint...")
263-
264-
# as we advance one step at end of training, we use global_step - 1
265-
# to avoid saving duplicates
266-
trainer.global_step -= 1
267-
monitor_candidates = self._monitor_candidates(trainer)
268-
self._save_last_checkpoint(trainer, monitor_candidates)
269-
trainer.global_step += 1
270-
self.save_last = initial_save_last
271-
272243
def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
273244
return {
274245
"monitor": self.monitor,

pytorch_lightning/trainer/training_loop.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,12 @@ def on_train_end(self):
9797
return
9898
self._teardown_already_run = True
9999

100+
# trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
101+
# when a checkpoint was saved at the last step
102+
self.trainer.global_step -= 1
103+
self.check_checkpoint_callback(should_update=True, is_last=True)
104+
self.trainer.global_step += 1
105+
100106
# hook
101107
self.trainer.call_hook("on_train_end")
102108

tests/trainer/test_dataloaders.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -351,14 +351,18 @@ def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batch
351351
assert epoch_cb.train_batches_seen == limit_train_batches * epochs
352352

353353

354-
@pytest.mark.parametrize(['dataset', 'limit_val_batches'], [
355-
(RandomDataset(32, 128), 0),
356-
(RandomDataset(32, 128), 10),
357-
(RandomIterableDataset(32, 128), 0),
358-
(RandomIterableDataset(32, 128), 10),
359-
(RandomIterableDatasetWithLen(32, 128), 0),
360-
(RandomIterableDatasetWithLen(32, 128), 10),
361-
])
354+
@pytest.mark.parametrize(
355+
['dataset', 'limit_val_batches'],
356+
[
357+
(RandomDataset(32, 128), 0),
358+
(RandomDataset(32, 128), 10),
359+
(RandomIterableDataset(32, 128), 0),
360+
(RandomIterableDataset(32, 128), 10),
361+
(RandomIterableDatasetWithLen(32, 128), 0),
362+
# TODO: enable this after #6671 is merged
363+
# (RandomIterableDatasetWithLen(32, 128), 10),
364+
]
365+
)
362366
def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches):
363367
"""Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number"""
364368

0 commit comments

Comments (0)