Commit dec43b9

Remove ModelCheckpoint.on_train_end

1 parent f663cfd

3 files changed (+2, -40 lines)

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 0 additions & 13 deletions

@@ -315,19 +315,6 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             return
         self.save_checkpoint(trainer)
 
-    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        """Save a checkpoint when training stops.
-
-        This will only save a checkpoint if `save_last` is also enabled as the monitor metrics logged during
-        training/validation steps or end of epochs are not guaranteed to be available at this stage.
-        """
-        if self._should_skip_saving_checkpoint(trainer) or not self.save_last:
-            return
-        if self.verbose:
-            rank_zero_info("Saving latest checkpoint...")
-        monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step)
-        self._save_last_checkpoint(trainer, monitor_candidates)
-
     def on_save_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
     ) -> Dict[str, Any]:
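
With the hook removed, `fit` no longer writes a final `save_last` checkpoint when training ends. For code that relied on that write, the sketch below is one way to restore it from user land; the callback name, guard, and file path are illustrative assumptions, and it deliberately sticks to the public `trainer.save_checkpoint` API instead of the private helpers the deleted hook called.

import os

import pytorch_lightning as pl


class SaveOnTrainEnd(pl.Callback):
    """Illustrative stand-in for the removed hook: write one final
    checkpoint when training stops, using only public Trainer API."""

    def __init__(self, dirpath: str, filename: str = "last.ckpt") -> None:
        self.dirpath = dirpath
        self.filename = filename

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # The removed hook skipped saving under fast_dev_run and similar
        # conditions via the private _should_skip_saving_checkpoint; the
        # fast_dev_run check below covers the common case publicly.
        if trainer.fast_dev_run:
            return
        trainer.save_checkpoint(os.path.join(self.dirpath, self.filename))

Registered as usual via Trainer(callbacks=[SaveOnTrainEnd("checkpoints")]). Unlike the removed hook, this saves unconditionally rather than only when a ModelCheckpoint has save_last enabled.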

tests/checkpointing/test_checkpoint_callback_frequency.py

Lines changed: 2 additions & 2 deletions

@@ -81,8 +81,8 @@ def training_step(self, batch, batch_idx):
     trainer.fit(model)
 
     if save_last:
-        # last epochs are saved every step (so double the save calls) and once `on_train_end`
-        expected = expected * 2 + 1
+        # last epochs are saved every step (so double the save calls)
+        expected = expected * 2
     assert save_mock.call_count == expected
 
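
The updated arithmetic encodes how save_last interacts with monitored saves: each monitored write also refreshes last.ckpt, so the mocked save count doubles, and after this commit there is no extra "+ 1" write at train end. A self-contained sketch of such a configuration follows (the model, metric name, and hyperparameters are illustrative, not taken from the test):

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModel(LightningModule):
    """Minimal module that logs the monitored metric during training."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log("my_loss", loss, on_epoch=True)
        return loss

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(8, 32), batch_size=4)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


ckpt = ModelCheckpoint(dirpath="checkpoints", monitor="my_loss", save_last=True)
trainer = Trainer(max_epochs=2, limit_train_batches=2, callbacks=[ckpt])
trainer.fit(TinyModel())
# "checkpoints/" now holds the best checkpoint plus last.ckpt: every
# monitored save also rewrote last.ckpt (the doubling the comment above
# refers to), with no additional write from on_train_end.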

tests/checkpointing/test_model_checkpoint.py

Lines changed: 0 additions & 25 deletions

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 import math
 import os
 import pickle

@@ -776,30 +775,6 @@ def test_default_checkpoint_behavior(tmpdir):
     assert ckpts[0] == "epoch=2-step=15.ckpt"
 
 
-@pytest.mark.parametrize("max_epochs", [1, 2])
-@pytest.mark.parametrize("should_validate", [True, False])
-@pytest.mark.parametrize("save_last", [True, False])
-@pytest.mark.parametrize("verbose", [True, False])
-def test_model_checkpoint_save_last_warning(
-    tmpdir, caplog, max_epochs: int, should_validate: bool, save_last: bool, verbose: bool
-):
-    """Tests 'Saving latest checkpoint...' log."""
-    # set a high `every_n_epochs` to avoid saving in `on_train_epoch_end`. the message is only printed `on_train_end`
-    # but it would get skipped because it got already saved in `on_train_epoch_end` for the same global step
-    ckpt = ModelCheckpoint(dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose, every_n_epochs=123)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        callbacks=[ckpt],
-        max_epochs=max_epochs,
-        limit_train_batches=1,
-        limit_val_batches=int(should_validate),
-    )
-    model = BoringModel()
-    with caplog.at_level(logging.INFO):
-        trainer.fit(model)
-    assert caplog.messages.count("Saving latest checkpoint...") == (verbose and save_last)
-
-
 def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     """Tests that the save_last checkpoint contains the latest information."""
     seed_everything(100)
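
The caplog pattern in the deleted test is plain pytest rather than anything Lightning-specific, and it remains useful for asserting on rank-zero log output. A minimal sketch of the technique in isolation (the logger and message are stand-ins, not part of this commit):

import logging


def emit_status() -> None:
    # Stand-in for library code that logs at INFO level.
    logging.getLogger(__name__).info("Saving latest checkpoint...")


def test_logs_message(caplog):
    # pytest's caplog fixture records messages emitted at or above the
    # requested level inside the context manager.
    with caplog.at_level(logging.INFO):
        emit_status()
    assert caplog.messages.count("Saving latest checkpoint...") == 1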
