Skip to content

Commit 8efb6e9

Browse files
committed
one more
1 parent 0db92d8 commit 8efb6e9

File tree

2 files changed

+9
-15
lines changed

2 files changed

+9
-15
lines changed

src/pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -351,19 +351,12 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
351351

352352
self.best_model_path = state_dict["best_model_path"]
353353

354-
def save_checkpoint(self, trainer: "pl.Trainer") -> None: # pragma: no-cover
355-
"""Performs the main logic around saving a checkpoint.
356-
357-
This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
358-
behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
359-
"""
360-
rank_zero_deprecation(
361-
f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and will be removed in v1.8."
362-
" Instead, you can use `trainer.save_checkpoint()` to manually save a checkpoint."
354+
def save_checkpoint(self, trainer: "pl.Trainer") -> None:
355+
raise NotImplementedError(
356+
f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and is no longer supported"
357+
f" as of 1.8. Please use `trainer.save_checkpoint()` to manually save a checkpoint. This method will be"
358+
f" removed completely in v2.0."
363359
)
364-
monitor_candidates = self._monitor_candidates(trainer)
365-
self._save_topk_checkpoint(trainer, monitor_candidates)
366-
self._save_last_checkpoint(trainer, monitor_candidates)
367360

368361
def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
369362
if self.save_top_k == 0:

tests/tests_pytorch/deprecated_api/test_remove_2-0.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,10 +292,11 @@ def on_pretrain_routine_end(self, trainer, pl_module):
292292
trainer.fit(model)
293293

294294

295-
def test_deprecated_mc_save_checkpoint():
295+
def test_v2_0_0_deprecated_mc_save_checkpoint():
296296
mc = ModelCheckpoint()
297297
trainer = Trainer()
298-
with mock.patch.object(trainer, "save_checkpoint"), pytest.deprecated_call(
299-
match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6"
298+
with mock.patch.object(trainer, "save_checkpoint"), pytest.raises(
299+
NotImplementedError,
300+
match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6 and is no longer supported as of 1.8.",
300301
):
301302
mc.save_checkpoint(trainer)

0 commit comments

Comments
 (0)