Error messages for removed DataModule hooks #15072

Merged · 13 commits · Oct 11, 2022
1 change: 1 addition & 0 deletions src/pytorch_lightning/_graveyard/__init__.py
@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytorch_lightning._graveyard.core
import pytorch_lightning._graveyard.trainer
import pytorch_lightning._graveyard.training_type # noqa: F401
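
Adding this import means the monkeypatching in the new `_graveyard/core.py` module below runs as soon as `pytorch_lightning` is imported, so the stubs are active without constructing a Trainer. A quick illustrative sketch of the resulting behavior (assuming a bare `LightningDataModule` instance; not part of this diff):

import pytorch_lightning as pl

dm = pl.LightningDataModule()
try:
    # the class-level stub installed by the graveyard module raises immediately
    dm.on_save_checkpoint({})
except NotImplementedError as err:
    print(err)  # "`LightningDataModule.on_save_checkpoint` was deprecated in v1.6 and removed in v1.8. Use `state_dict` instead."
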
56 changes: 56 additions & 0 deletions src/pytorch_lightning/_graveyard/core.py
@@ -0,0 +1,56 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

from pytorch_lightning import LightningDataModule, LightningModule


def _use_amp(_: LightningModule) -> None:
# Remove in v2.0.0
raise AttributeError(
"`LightningModule.use_amp` was deprecated in v1.6 and is no longer accessible as of v1.8."
" Please use `Trainer.amp_backend`.",
)


def _use_amp_setter(_: LightningModule, __: bool) -> None:
# Remove in v2.0.0
raise AttributeError(
"`LightningModule.use_amp` was deprecated in v1.6 and is no longer accessible as of v1.8."
" Please use `Trainer.amp_backend`.",
)


def _on_save_checkpoint(_: LightningDataModule, __: Any) -> None:
# Remove in v2.0.0
raise NotImplementedError(
"`LightningDataModule.on_save_checkpoint` was deprecated in v1.6 and removed in v1.8."
" Use `state_dict` instead."
)


def _on_load_checkpoint(_: LightningDataModule, __: Any) -> None:
# Remove in v2.0.0
raise NotImplementedError(
"`LightningDataModule.on_load_checkpoint` was deprecated in v1.6 and removed in v1.8."
" Use `load_state_dict` instead."
)


# Properties/Attributes
LightningModule.use_amp = property(fget=_use_amp, fset=_use_amp_setter)

# Methods
LightningDataModule.on_save_checkpoint = _on_save_checkpoint
LightningDataModule.on_load_checkpoint = _on_load_checkpoint
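
The replacement hooks that these error messages point to are `LightningDataModule.state_dict` and `LightningDataModule.load_state_dict`. A minimal migration sketch (hypothetical user code, not part of this diff; the `current_fold` attribute is only an example of datamodule state worth persisting):

from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.current_fold = 0

    # Before v1.8 this state could have been written inside `on_save_checkpoint(checkpoint)`;
    # now the datamodule returns its own state and Lightning stores it in the checkpoint.
    def state_dict(self):
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state_dict):
        self.current_fold = state_dict["current_fold"]
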
18 changes: 18 additions & 0 deletions src/pytorch_lightning/trainer/configuration_validator.py
@@ -54,6 +54,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
_check_on_epoch_start_end(model)
# TODO: Delete this check in v2.0
_check_on_pretrain_routine(model)
# TODO: Delete this check in v2.0
_check_unsupported_datamodule_hooks(trainer)


def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@@ -253,3 +255,19 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
raise RuntimeError(
f"The `Callback.{hook}` hook was removed in v1.8. Please use `Callback.on_fit_start` instead."
)


def _check_unsupported_datamodule_hooks(trainer: "pl.Trainer") -> None:
datahook_selector = trainer._data_connector._datahook_selector
assert datahook_selector is not None

if is_overridden("on_save_checkpoint", datahook_selector.datamodule):
raise NotImplementedError(
"`LightningDataModule.on_save_checkpoint` was deprecated in v1.6 and removed in v1.8."
" Use `state_dict` instead."
)
if is_overridden("on_load_checkpoint", datahook_selector.datamodule):
raise NotImplementedError(
"`LightningDataModule.on_load_checkpoint` was deprecated in v1.6 and removed in v1.8."
" Use `load_state_dict` instead."
)
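
This check leans on Lightning's `is_overridden` utility, which reports whether the hook on the given datamodule differs from the base-class implementation. A rough sketch of the idea (illustrative only, not Lightning's actual implementation):

from pytorch_lightning import LightningDataModule


def _is_overridden_sketch(method_name: str, instance, parent=LightningDataModule) -> bool:
    # Treat the hook as overridden when the function found on the instance's class
    # is not the same code object as the one defined on the parent class.
    if instance is None:
        return False
    instance_fn = getattr(type(instance), method_name, None)
    parent_fn = getattr(parent, method_name, None)
    if instance_fn is None or parent_fn is None:
        return False
    return getattr(instance_fn, "__code__", None) is not getattr(parent_fn, "__code__", None)
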
55 changes: 55 additions & 0 deletions tests/tests_pytorch/graveyard/test_core.py
@@ -0,0 +1,55 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel


def test_v2_0_0_unsupported_datamodule_on_save_load_checkpoint():
datamodule = BoringDataModule()
with pytest.raises(
NotImplementedError,
match="`LightningDataModule.on_save_checkpoint` was deprecated in v1.6 and removed in v1.8.",
):
datamodule.on_save_checkpoint({})

with pytest.raises(
NotImplementedError,
match="`LightningDataModule.on_load_checkpoint` was deprecated in v1.6 and removed in v1.8.",
):
datamodule.on_load_checkpoint({})

class OnSaveDataModule(BoringDataModule):
def on_save_checkpoint(self, checkpoint):
pass

class OnLoadDataModule(BoringDataModule):
def on_load_checkpoint(self, checkpoint):
pass

trainer = Trainer()
model = BoringModel()

with pytest.raises(
NotImplementedError,
match="`LightningDataModule.on_save_checkpoint` was deprecated in v1.6 and removed in v1.8.",
):
trainer.fit(model, OnSaveDataModule())

with pytest.raises(
NotImplementedError,
match="`LightningDataModule.on_load_checkpoint` was deprecated in v1.6 and removed in v1.8.",
):
trainer.fit(model, OnLoadDataModule())
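
For contrast, a datamodule that uses the replacement hooks passes this validation. A companion sketch, not part of the PR's test file, reusing the imports at the top of the file above (the class and test names are hypothetical; `fast_dev_run` keeps the run cheap):

class StateDictDataModule(BoringDataModule):
    def state_dict(self):
        return {"foo": "bar"}

    def load_state_dict(self, state_dict):
        self.foo = state_dict["foo"]


def test_datamodule_state_dict_hooks_are_supported(tmp_path):
    trainer = Trainer(default_root_dir=str(tmp_path), fast_dev_run=True)
    # No NotImplementedError: the configuration validator only rejects the removed hooks.
    trainer.fit(BoringModel(), StateDictDataModule())
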