
Remove deprecated on_load/save_checkpoint behavior #14835

Merged · 13 commits · Oct 10, 2022
7 changes: 7 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -126,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - HPC checkpoints are now loaded automatically only in slurm environment when no specific value for `ckpt_path` has been set ([#14911](https://github.com/Lightning-AI/lightning/pull/14911))


+- The `Callback.on_load_checkpoint` now gets the full checkpoint dictionary and the `callback_state` argument was renamed `checkpoint` ([#14835](https://github.com/Lightning-AI/lightning/pull/14835))
+
+
 ### Deprecated

 - Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))
@@ -302,6 +305,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed the deprecated `LightningDataModule.on_save/load_checkpoint` hooks ([#14909](https://github.com/Lightning-AI/lightning/pull/14909))


+- Removed support for returning a value in `Callback.on_save_checkpoint` in favor of implementing `Callback.state_dict` ([#14835](https://github.com/Lightning-AI/lightning/pull/14835))
+
+
+
 ### Fixed

 - Fixed an issue with `LightningLite.setup()` not setting the `.device` attribute correctly on the returned wrapper ([#14822](https://github.com/Lightning-AI/lightning/pull/14822))
25 changes: 3 additions & 22 deletions src/pytorch_lightning/callbacks/callback.py
@@ -277,45 +277,26 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

     def on_save_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
-    ) -> Optional[dict]:
+    ) -> None:
         r"""
         Called when saving a checkpoint to give you a chance to store anything else you might want to save.

         Args:
             trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
             pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance.
             checkpoint: the checkpoint dictionary that will be saved.
-
-        Returns:
-            None or the callback state. Support for returning callback state will be removed in v1.8.
-
-        .. deprecated:: v1.6
-            Returning a value from this method was deprecated in v1.6 and will be removed in v1.8.
-            Implement ``Callback.state_dict`` instead to return state.
-            In v1.8 ``Callback.on_save_checkpoint`` can only return None.
         """

     def on_load_checkpoint(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
     ) -> None:
         r"""
         Called when loading a model checkpoint, use to reload state.

         Args:
             trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
             pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance.
-            callback_state: the callback state returned by ``on_save_checkpoint``.
-
-        Note:
-            The ``on_load_checkpoint`` won't be called with an undefined state.
-            If your ``on_load_checkpoint`` hook behavior doesn't rely on a state,
-            you will still need to override ``on_save_checkpoint`` to return a ``dummy state``.
-
-        .. deprecated:: v1.6
-            This callback hook will change its signature and behavior in v1.8.
-            If you wish to load the state of the callback, use ``Callback.load_state_dict`` instead.
-            In v1.8 ``Callback.on_load_checkpoint(checkpoint)`` will receive the entire loaded
-            checkpoint dictionary instead of only the callback state from the checkpoint.
+            checkpoint: the full checkpoint dictionary that got loaded by the Trainer.
         """

     def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: Tensor) -> None:
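Taken together with the CHANGELOG entries above, the new contract is: persistent state goes through `Callback.state_dict`/`load_state_dict`, while `on_load_checkpoint` receives the full checkpoint. Below is a minimal migration sketch; only the hook names and signatures come from this diff, and the `CounterCallback` and its counter field are hypothetical illustrations:

from pytorch_lightning import Callback


class CounterCallback(Callback):
    """Hypothetical callback showing the replacement for the removed behavior."""

    def __init__(self) -> None:
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        self.batches_seen += 1

    def state_dict(self) -> dict:
        # Before v1.8 this dict could be returned from `on_save_checkpoint`;
        # now the Trainer persists whatever `state_dict` returns under
        # `checkpoint["callbacks"][self.state_key]`.
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict: dict) -> None:
        # Receives exactly what `state_dict` returned when the checkpoint was saved.
        self.batches_seen = state_dict["batches_seen"]

    def on_load_checkpoint(self, trainer, pl_module, checkpoint) -> None:
        # As of this PR, `checkpoint` is the entire loaded dictionary,
        # not just this callback's state.
        print(f"resuming, checkpoint keys: {sorted(checkpoint)}")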
4 changes: 1 addition & 3 deletions src/pytorch_lightning/callbacks/pruning.py
@@ -423,9 +423,7 @@ def move_to_cpu(tensor: Tensor) -> Tensor:

         return apply_to_collection(state_dict, Tensor, move_to_cpu)

-    def on_save_checkpoint(
-        self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any]
-    ) -> Optional[dict]:
+    def on_save_checkpoint(self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any]) -> None:
         if self._make_pruning_permanent:
             rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint")
             # manually prune the weights so training can keep going with the same buffers
16 changes: 11 additions & 5 deletions src/pytorch_lightning/trainer/configuration_validator.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
+
 import pytorch_lightning as pl
 from lightning_lite.utilities.warnings import PossibleUserWarning
 from pytorch_lightning.accelerators.ipu import IPUAccelerator
@@ -220,12 +222,16 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
                 "The `on_before_accelerator_backend_setup` callback hook was deprecated in"
                 " v1.6 and will be removed in v1.8. Use `setup()` instead."
             )
-        if is_overridden(method_name="on_load_checkpoint", instance=callback):
-            rank_zero_deprecation(
-                f"`{callback.__class__.__name__}.on_load_checkpoint` will change its signature and behavior in v1.8."
+
+        has_legacy_argument = "callback_state" in inspect.signature(callback.on_load_checkpoint).parameters
+        if is_overridden(method_name="on_load_checkpoint", instance=callback) and has_legacy_argument:
+            # TODO: Remove this error message in v2.0
+            raise TypeError(
+                f"`{callback.__class__.__name__}.on_load_checkpoint` has changed its signature and behavior in v1.8."
                 " If you wish to load the state of the callback, use `load_state_dict` instead."
-                " In v1.8 `on_load_checkpoint(..., checkpoint)` will receive the entire loaded"
-                " checkpoint dictionary instead of callback state."
+                " As of 1.8, `on_load_checkpoint(..., checkpoint)` receives the entire loaded"
+                " checkpoint dictionary instead of the callback state. To continue using this hook and avoid this error"
+                " message, rename the `callback_state` argument to `checkpoint`."
            )

     for hook, alternative_hook in (
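The backward-compatibility check added above keys off nothing but the parameter name. A standalone sketch of the idea, with illustrative class names that are not from the PR:

import inspect


class LegacyCallback:
    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        ...


class UpdatedCallback:
    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        ...


def has_legacy_argument(callback) -> bool:
    # `inspect.signature` on the bound method skips `self`, leaving
    # (trainer, pl_module, callback_state/checkpoint).
    return "callback_state" in inspect.signature(callback.on_load_checkpoint).parameters


assert has_legacy_argument(LegacyCallback())        # the validator would now raise a TypeError
assert not has_legacy_argument(UpdatedCallback())   # accepted; receives the full checkpoint dict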
25 changes: 9 additions & 16 deletions src/pytorch_lightning/trainer/trainer.py
@@ -1354,11 +1354,7 @@ def _call_callbacks_state_dict(self) -> Dict[str, dict]:
         return callback_state_dicts

     def _call_callbacks_on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        """Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook.
-
-        Will be removed in v1.8: If state is returned, we insert the callback state into
-        ``checkpoint["callbacks"][Callback.state_key]``. It overrides ``state_dict`` if already present.
-        """
+        """Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook."""
         pl_module = self.lightning_module
         if pl_module:
             prev_fx_name = pl_module._current_fx_name
@@ -1367,13 +1363,13 @@ def _call_callbacks_on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         for callback in self.callbacks:
             with self.profiler.profile(f"[Callback]{callback.state_key}.on_save_checkpoint"):
                 state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
-            if state:
-                rank_zero_deprecation(
-                    f"Returning a value from `{callback.__class__.__name__}.on_save_checkpoint` is deprecated in v1.6"
-                    " and will be removed in v1.8. Please override `Callback.state_dict`"
-                    " to return state to be saved."
+            if state is not None:
+                # TODO: Remove this error message in v2.0
+                raise ValueError(
+                    f"Returning a value from `{callback.__class__.__name__}.on_save_checkpoint` was deprecated in v1.6"
+                    f" and is no longer supported as of v1.8. Please override `Callback.state_dict` to return state"
+                    f" to be saved."
                 )
-                checkpoint["callbacks"][callback.state_key] = state

         if pl_module:
             # restore current_fx when nested context
@@ -1406,11 +1402,8 @@ def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
             )

         for callback in self.callbacks:
-            state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
-            if state:
-                state = deepcopy(state)
-                with self.profiler.profile(f"[Callback]{callback.state_key}.on_load_checkpoint"):
-                    callback.on_load_checkpoint(self, self.lightning_module, state)
+            with self.profiler.profile(f"[Callback]{callback.state_key}.on_load_checkpoint"):
+                callback.on_load_checkpoint(self, self.lightning_module, checkpoint)

         if pl_module:
             # restore current_fx when nested context
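Note the behavior change on the load path: the Trainer previously looked up each callback's entry (falling back to `_legacy_state_key`), deep-copied it, and skipped the hook entirely when no state was found; now the hook always runs and gets the whole dictionary. A sketch, with a hypothetical callback and field name, of how user code could still reach its own saved entry under the new contract, although `load_state_dict` is the intended replacement:

from pytorch_lightning import Callback


class ManualStateCallback(Callback):
    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        # Per-callback entries still live under `checkpoint["callbacks"]`,
        # keyed by `state_key`, exactly where the Trainer used to look them up.
        state = checkpoint.get("callbacks", {}).get(self.state_key)
        if state is not None:
            self.counter = state.get("counter", 0)  # `counter` is a hypothetical field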
53 changes: 0 additions & 53 deletions tests/tests_pytorch/deprecated_api/test_remove_1-8.py
@@ -200,56 +200,3 @@ def test_deprecated_mc_save_checkpoint():
         match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6"
     ):
         mc.save_checkpoint(trainer)
-
-
-def test_v1_8_0_callback_on_load_checkpoint_hook(tmpdir):
-    class TestCallbackLoadHook(Callback):
-        def on_load_checkpoint(self, trainer, pl_module, callback_state):
-            print("overriding on_load_checkpoint")
-
-    model = BoringModel()
-    trainer = Trainer(
-        callbacks=[TestCallbackLoadHook()],
-        max_epochs=1,
-        fast_dev_run=True,
-        enable_progress_bar=False,
-        logger=False,
-        default_root_dir=tmpdir,
-    )
-    with pytest.deprecated_call(
-        match="`TestCallbackLoadHook.on_load_checkpoint` will change its signature and behavior in v1.8."
-        " If you wish to load the state of the callback, use `load_state_dict` instead."
-        r" In v1.8 `on_load_checkpoint\(..., checkpoint\)` will receive the entire loaded"
-        " checkpoint dictionary instead of callback state."
-    ):
-        trainer.fit(model)
-
-
-def test_v1_8_0_callback_on_save_checkpoint_hook(tmpdir):
-    class TestCallbackSaveHookReturn(Callback):
-        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
-            return {"returning": "on_save_checkpoint"}
-
-    class TestCallbackSaveHookOverride(Callback):
-        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
-            print("overriding without returning")
-
-    model = BoringModel()
-    trainer = Trainer(
-        callbacks=[TestCallbackSaveHookReturn()],
-        max_epochs=1,
-        fast_dev_run=True,
-        enable_progress_bar=False,
-        logger=False,
-        default_root_dir=tmpdir,
-    )
-    trainer.fit(model)
-    with pytest.deprecated_call(
-        match="Returning a value from `TestCallbackSaveHookReturn.on_save_checkpoint` is deprecated in v1.6"
-        " and will be removed in v1.8. Please override `Callback.state_dict`"
-        " to return state to be saved."
-    ):
-        trainer.save_checkpoint(tmpdir + "/path.ckpt")
-
-    trainer.callbacks = [TestCallbackSaveHookOverride()]
-    trainer.save_checkpoint(tmpdir + "/pathok.ckpt")
54 changes: 53 additions & 1 deletion tests/tests_pytorch/deprecated_api/test_remove_2-0.py
@@ -17,7 +17,7 @@
 import pytest

 import pytorch_lightning
-from pytorch_lightning import Trainer
+from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
 from tests_pytorch.callbacks.test_callbacks import OldStatefulCallback
 from tests_pytorch.helpers.runif import RunIf
@@ -84,3 +84,55 @@ def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir):
     trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
     with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
         trainer.fit(model, ckpt_path="fit_arg_ckpt_path")
+
+
+def test_v2_0_callback_on_load_checkpoint_hook(tmpdir):
+    class TestCallbackLoadHook(Callback):
+        def on_load_checkpoint(self, trainer, pl_module, callback_state):
+            print("overriding on_load_checkpoint")
+
+    model = BoringModel()
+    trainer = Trainer(
+        callbacks=[TestCallbackLoadHook()],
+        max_epochs=1,
+        fast_dev_run=True,
+        enable_progress_bar=False,
+        logger=False,
+        default_root_dir=tmpdir,
+    )
+    with pytest.raises(
+        TypeError, match="`TestCallbackLoadHook.on_load_checkpoint` has changed its signature and behavior in v1.8."
+    ):
+        trainer.fit(model)
+
+
+def test_v2_0_callback_on_save_checkpoint_hook(tmpdir):
+    class TestCallbackSaveHookReturn(Callback):
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            return {"returning": "on_save_checkpoint"}
+
+    class TestCallbackSaveHookOverride(Callback):
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            print("overriding without returning")
+
+    model = BoringModel()
+    trainer = Trainer(
+        callbacks=[TestCallbackSaveHookReturn()],
+        max_epochs=1,
+        fast_dev_run=True,
+        enable_progress_bar=False,
+        logger=False,
+        default_root_dir=tmpdir,
+    )
+    trainer.fit(model)
+    with pytest.raises(
+        ValueError,
+        match=(
+            "Returning a value from `TestCallbackSaveHookReturn.on_save_checkpoint` was deprecated in v1.6 and is"
+            " no longer supported as of v1.8"
+        ),
+    ):
+        trainer.save_checkpoint(tmpdir + "/path.ckpt")
+
+    trainer.callbacks = [TestCallbackSaveHookOverride()]
+    trainer.save_checkpoint(tmpdir + "/pathok.ckpt")