Skip to content

Add remove_checkpoint to CheckpointIO plugin to simplify ModelCheckpo… #9373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 10, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183))


- Add a warning to deepspeed when inferring batch size ([#9221](https://github.com/PyTorchLightning/pytorch-lightning/pull/9221))
- Added a warning to deepspeed when inferring batch size ([#9221](https://github.com/PyTorchLightning/pytorch-lightning/pull/9221))


- Added `inference_mode` for evaluation and prediction ([8813](https://github.com/PyTorchLightning/pytorch-lightning/pull/8813))
- Added `inference_mode` for evaluation and prediction ([#8813](https://github.com/PyTorchLightning/pytorch-lightning/pull/8813))


- Added `remove_checkpoint` to `CheckpointIO` plugin by moving the responsibility from `ModelCheckpoint` Callback ([#9373](https://github.com/PyTorchLightning/pytorch-lightning/pull/9373))


### Changed
Expand Down
34 changes: 8 additions & 26 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,19 +486,6 @@ def __init_triggers(
def every_n_epochs(self) -> Optional[int]:
return self._every_n_epochs

def _del_model(self, trainer: "pl.Trainer", filepath: str) -> None:
if trainer.should_rank_save_checkpoint and self._fs.exists(filepath):
self._fs.rm(filepath, recursive=True)
log.debug(f"Removed checkpoint: {filepath}")

def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None:
# make paths
if trainer.should_rank_save_checkpoint:
self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

# delegate the saving to the trainer
trainer.save_checkpoint(filepath, self.save_weights_only)

def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Tensor] = None) -> bool:
if current is None:
return False
Expand Down Expand Up @@ -671,10 +658,10 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
filepath = self._format_checkpoint_name(self.CHECKPOINT_NAME_LAST, monitor_candidates)
filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}")

self._save_model(trainer, filepath)
trainer.save_checkpoint(filepath, self.save_weights_only)

if self.last_model_path and self.last_model_path != filepath and trainer.should_rank_save_checkpoint:
self._del_model(trainer, self.last_model_path)
if self.last_model_path and self.last_model_path != filepath:
trainer.training_type_plugin.remove_checkpoint(self.last_model_path)

self.last_model_path = filepath

Expand All @@ -696,15 +683,10 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
return

filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
self._save_model(trainer, filepath)
trainer.save_checkpoint(filepath, self.save_weights_only)

if (
self.save_top_k == 1
and self.best_model_path
and self.best_model_path != filepath
and trainer.should_rank_save_checkpoint
):
self._del_model(trainer, self.best_model_path)
if self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath:
trainer.training_type_plugin.remove_checkpoint(self.best_model_path)

self.best_model_path = filepath

Expand Down Expand Up @@ -748,10 +730,10 @@ def _update_best_and_save(
f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
)
self._save_model(trainer, filepath)
trainer.save_checkpoint(filepath, self.save_weights_only)

if del_filepath is not None and filepath != del_filepath:
self._del_model(trainer, del_filepath)
trainer.training_type_plugin.remove_checkpoint(del_filepath)

def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None:
"""Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
Expand Down
8 changes: 8 additions & 0 deletions pytorch_lightning/plugins/io/checkpoint_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,11 @@ def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) ->

Returns: The loaded checkpoint.
"""

@abstractmethod
def remove_checkpoint(self, path: _PATH) -> None:
"""Remove checkpoint filepath from the filesystem.

Args:
path: Path to checkpoint
"""
17 changes: 17 additions & 0 deletions pytorch_lightning/plugins/io/torch_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Any, Callable, Dict, Optional

import pytorch_lightning as pl
Expand All @@ -20,12 +22,16 @@
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.types import _PATH

log = logging.getLogger(__name__)


class TorchCheckpointIO(CheckpointIO):
"""CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints
respectively, common for most use cases."""

def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
fs = get_filesystem(path)
fs.makedirs(os.path.dirname(path), exist_ok=True)
try:
# write the checkpoint dictionary on the file
atomic_save(checkpoint, path)
Expand Down Expand Up @@ -60,3 +66,14 @@ def load_checkpoint(
raise FileNotFoundError(f"Checkpoint at {path} not found. Aborting training.")

return pl_load(path, map_location=map_location)

def remove_checkpoint(self, path: _PATH) -> None:
"""Remove checkpoint file from the filesystem.

Args:
path: Path to checkpoint
"""
fs = get_filesystem(path)
if fs.exists(path):
fs.rm(path, recursive=True)
log.debug(f"Removed checkpoint: {path}")
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,15 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
if self.should_rank_save_checkpoint:
return self.checkpoint_io.save_checkpoint(checkpoint, filepath)

def remove_checkpoint(self, filepath: str) -> None:
"""Remove checkpoint filepath from the filesystem.

Args:
filepath: Path to checkpoint
"""
if self.should_rank_save_checkpoint:
return self.checkpoint_io.remove_checkpoint(filepath)

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
"""Provide hook to create modules in a distributed aware context. This is useful for when we'd like to
Expand Down
14 changes: 8 additions & 6 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,17 +760,18 @@ def test_default_checkpoint_behavior(tmpdir):
default_root_dir=tmpdir, max_epochs=3, progress_bar_refresh_rate=0, limit_train_batches=5, limit_val_batches=5
)

with patch.object(ModelCheckpoint, "_save_model", wraps=trainer.checkpoint_callback._save_model) as save_mock:
with patch.object(trainer, "save_checkpoint", wraps=trainer.save_checkpoint) as save_mock:
trainer.fit(model)
results = trainer.test()

assert len(results) == 1
save_dir = tmpdir / "lightning_logs" / "version_0" / "checkpoints"
save_weights_only = trainer.checkpoint_callback.save_weights_only
save_mock.assert_has_calls(
[
call(trainer, save_dir / "epoch=0-step=4.ckpt"),
call(trainer, save_dir / "epoch=1-step=9.ckpt"),
call(trainer, save_dir / "epoch=2-step=14.ckpt"),
call(save_dir / "epoch=0-step=4.ckpt", save_weights_only),
call(save_dir / "epoch=1-step=9.ckpt", save_weights_only),
call(save_dir / "epoch=2-step=14.ckpt", save_weights_only),
]
)
ckpts = os.listdir(save_dir)
Expand Down Expand Up @@ -852,18 +853,19 @@ def validation_epoch_end(self, outputs):
model = CurrentModel()

callback = ModelCheckpoint(monitor="abc", mode=mode, save_top_k=1, dirpath=tmpdir)
callback._save_model = MagicMock()

trainer = Trainer(
callbacks=[callback],
default_root_dir=tmpdir,
val_check_interval=1.0,
max_epochs=len(monitor),
)
trainer.save_checkpoint = MagicMock()

trainer.fit(model)

# check that last one is also the best one
assert callback._save_model.call_count == len(monitor)
assert trainer.save_checkpoint.call_count == len(monitor)
assert mode == "min" and callback.best_model_score == 5 or mode == "max" and callback.best_model_score == 8


Expand Down
17 changes: 13 additions & 4 deletions tests/plugins/test_checkpoint_io_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, Optional
from unittest.mock import MagicMock

Expand All @@ -33,6 +34,9 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]:
return torch.load(path)

def remove_checkpoint(self, path: _PATH) -> None:
os.remove(path)


def test_checkpoint_plugin_called(tmpdir):
"""Ensure that the custom checkpoint IO plugin and torch checkpoint IO plugin is called when saving/loading."""
Expand All @@ -47,10 +51,13 @@ def test_checkpoint_plugin_called(tmpdir):
default_root_dir=tmpdir,
plugins=SingleDevicePlugin(device, checkpoint_io=checkpoint_plugin),
callbacks=ck,
max_epochs=1,
max_epochs=2,
)
trainer.fit(model)
assert checkpoint_plugin.save_checkpoint.call_count == 3

assert checkpoint_plugin.save_checkpoint.call_count == 5
assert checkpoint_plugin.remove_checkpoint.call_count == 1

trainer.test(model, ckpt_path=ck.last_model_path)
checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last.ckpt")

Expand All @@ -63,10 +70,12 @@ def test_checkpoint_plugin_called(tmpdir):
default_root_dir=tmpdir,
plugins=[SingleDevicePlugin(device), checkpoint_plugin],
callbacks=ck,
max_epochs=1,
max_epochs=2,
)
trainer.fit(model)
assert checkpoint_plugin.save_checkpoint.call_count == 3

assert checkpoint_plugin.save_checkpoint.call_count == 5
assert checkpoint_plugin.remove_checkpoint.call_count == 1

trainer.test(model, ckpt_path=ck.last_model_path)
checkpoint_plugin.load_checkpoint.assert_called_once()
Expand Down