Skip to content

Remove deprecated precision plugin checkpoint hooks #14833

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed the deprecated way to set the distributed backend via the environment variable `PL_TORCH_DISTRIBUTED_BACKEND`, in favor of setting the `process_group_backend` in the strategy constructor ([#14693](https://github.com/Lightning-AI/lightning/pull/14693))


- Removed the deprecated device attributes `Trainer.{devices,gpus,num_gpus,ipus,tpu_cores}` in favor of the accelerator-agnostic `Trainer.num_devices` ([#14829](https://github.com/Lightning-AI/lightning/pull/14829))


- Removed the deprecated precision plugin checkpoint hooks `PrecisionPlugin.on_load_checkpoint` and `PrecisionPlugin.on_save_checkpoint` ([#14833](https://github.com/Lightning-AI/lightning/pull/14833))


- Removed the deprecated `Trainer.root_gpu` attribute in favor of `Trainer.strategy.root_device` ([#14829](https://github.com/Lightning-AI/lightning/pull/14829))


Expand Down
12 changes: 0 additions & 12 deletions src/pytorch_lightning/plugins/precision/precision_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,3 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
state_dict: the precision plugin state returned by ``state_dict``.
"""
pass

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    """``PrecisionPlugin.on_save_checkpoint`` was deprecated in v1.6 and will be removed in v1.8.

    Use ``state_dict`` instead.
    """
    # Deliberately a no-op: this stub exists only so that `is_overridden` checks
    # (see `_check_precision_plugin_checkpoint_hooks`) can detect user subclasses
    # that still override this deprecated hook and emit a deprecation warning.

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    """``PrecisionPlugin.on_load_checkpoint`` was deprecated in v1.6 and will be removed in v1.8.

    Use ``load_state_dict`` instead.
    """
    # Deliberately a no-op: this stub exists only so that `is_overridden` checks
    # (see `_check_precision_plugin_checkpoint_hooks`) can detect user subclasses
    # that still override this deprecated hook and emit a deprecation warning.
16 changes: 0 additions & 16 deletions src/pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import pytorch_lightning as pl
from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning.accelerators.ipu import IPUAccelerator
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.strategies import DataParallelStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -50,8 +49,6 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
_check_deprecated_callback_hooks(trainer)
# TODO: Delete on_epoch_start/on_epoch_end hooks in v1.8
_check_on_epoch_start_end(model)
# TODO: Delete CheckpointHooks off PrecisionPlugin in v1.8
_check_precision_plugin_checkpoint_hooks(trainer)
# TODO: Delete on_pretrain_routine_start/end hooks in v1.8
_check_on_pretrain_routine(model)
# TODO: Delete CheckpointHooks off LightningDataModule in v1.8
Expand Down Expand Up @@ -266,19 +263,6 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
)


def _check_precision_plugin_checkpoint_hooks(trainer: "pl.Trainer") -> None:
    """Warn if the trainer's precision plugin overrides the deprecated checkpoint hooks.

    ``on_save_checkpoint`` and ``on_load_checkpoint`` on ``PrecisionPlugin`` were
    deprecated in v1.6 in favor of ``state_dict`` / ``load_state_dict``.
    """
    # Each deprecated hook is paired with its replacement API; both branches emit
    # the same style of deprecation message.
    deprecated_hooks = (
        ("on_save_checkpoint", "state_dict"),
        ("on_load_checkpoint", "load_state_dict"),
    )
    for hook_name, replacement in deprecated_hooks:
        if is_overridden(method_name=hook_name, instance=trainer.precision_plugin, parent=PrecisionPlugin):
            rank_zero_deprecation(
                f"`PrecisionPlugin.{hook_name}` was deprecated in"
                f" v1.6 and will be removed in v1.8. Use `{replacement}` instead."
            )


def _check_datamodule_checkpoint_hooks(trainer: "pl.Trainer") -> None:
if is_overridden(method_name="on_save_checkpoint", instance=trainer.datamodule):
rank_zero_deprecation(
Expand Down
29 changes: 0 additions & 29 deletions tests/tests_pytorch/deprecated_api/test_remove_1-8.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
from pytorch_lightning.loggers import CSVLogger, Logger
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.profilers import AdvancedProfiler, SimpleProfiler
from pytorch_lightning.strategies.ipu import LightningIPUModule
from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks
Expand Down Expand Up @@ -434,34 +433,6 @@ def _get_python_cprofile_total_duration(profile):
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)


def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir):
    """Overriding the deprecated checkpoint hooks on a precision plugin raises a deprecation warning on fit."""

    class SaveHookPlugin(PrecisionPlugin):
        def on_save_checkpoint(self, checkpoint):
            print("override on_save_checkpoint")

    class LoadHookPlugin(PrecisionPlugin):
        def on_load_checkpoint(self, checkpoint):
            print("override on_load_checkpoint")

    # Expected deprecation messages (used as regex patterns by `deprecated_call`).
    save_warning = (
        "`PrecisionPlugin.on_save_checkpoint` was deprecated in"
        " v1.6 and will be removed in v1.8. Use `state_dict` instead."
    )
    load_warning = (
        "`PrecisionPlugin.on_load_checkpoint` was deprecated in"
        " v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
    )

    model = BoringModel()

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, plugins=[SaveHookPlugin()])
    with pytest.deprecated_call(match=save_warning):
        trainer.fit(model)

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, plugins=[LoadHookPlugin()])
    with pytest.deprecated_call(match=load_warning):
        trainer.fit(model)


def test_v1_8_0_datamodule_checkpointhooks():
class CustomBoringDataModuleSave(BoringDataModule):
def on_save_checkpoint(self, checkpoint):
Expand Down