Skip to content

Commit 1e5411b

Browse files
NeoKisha, awaelchli, and otaj
authored
Removed the deprecated datamodule_checkpointhooks (#14909)
Co-authored-by: awaelchli <[email protected]> Co-authored-by: otaj <[email protected]>
1 parent 4c43e57 commit 1e5411b

File tree

8 files changed

+17
-82
lines changed

8 files changed

+17
-82
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
274274

275275
- Removed the deprecated `Trainer.{validated,tested,predicted}_ckpt_path` ([#14897](https://github.com/Lightning-AI/lightning/pull/14897))
276276

277+
- Removed the deprecated `LightningDataModule.on_save/load_checkpoint` hooks ([#14909](https://github.com/Lightning-AI/lightning/pull/14909))
277278

278279
### Fixed
279280

src/pytorch_lightning/core/datamodule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import pytorch_lightning as pl
2222
from lightning_lite.utilities.types import _PATH
23-
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
23+
from pytorch_lightning.core.hooks import DataHooks
2424
from pytorch_lightning.core.mixins import HyperparametersMixin
2525
from pytorch_lightning.core.saving import _load_from_checkpoint
2626
from pytorch_lightning.utilities.argparse import (
@@ -32,7 +32,7 @@
3232
from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN, EVAL_DATALOADERS, TRAIN_DATALOADERS
3333

3434

35-
class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
35+
class LightningDataModule(DataHooks, HyperparametersMixin):
3636
"""A DataModule standardizes the training, val, test splits, data preparation and transforms. The main
3737
advantage is consistent data splits, data preparation and transforms across models.
3838

src/pytorch_lightning/core/saving.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,9 @@ def _load_state(
222222

223223
obj = cls(**_cls_kwargs)
224224

225-
# give model a chance to load something
226-
obj.on_load_checkpoint(checkpoint)
225+
if isinstance(obj, pl.LightningModule):
226+
# give model a chance to load something
227+
obj.on_load_checkpoint(checkpoint)
227228

228229
if isinstance(obj, pl.LightningDataModule):
229230
if obj.__class__.__qualname__ in checkpoint:

src/pytorch_lightning/trainer/configuration_validator.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
5151
_check_on_epoch_start_end(model)
5252
# TODO: Delete on_pretrain_routine_start/end hooks in v1.8
5353
_check_on_pretrain_routine(model)
54-
# TODO: Delete CheckpointHooks off LightningDataModule in v1.8
55-
_check_datamodule_checkpoint_hooks(trainer)
5654

5755

5856
def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@@ -261,16 +259,3 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
261259
f"The `Callback.{hook}` hook has been deprecated in v1.6 and"
262260
" will be removed in v1.8. Please use `Callback.on_fit_start` instead."
263261
)
264-
265-
266-
def _check_datamodule_checkpoint_hooks(trainer: "pl.Trainer") -> None:
267-
if is_overridden(method_name="on_save_checkpoint", instance=trainer.datamodule):
268-
rank_zero_deprecation(
269-
"`LightningDataModule.on_save_checkpoint` was deprecated in"
270-
" v1.6 and will be removed in v1.8. Use `state_dict` instead."
271-
)
272-
if is_overridden(method_name="on_load_checkpoint", instance=trainer.datamodule):
273-
rank_zero_deprecation(
274-
"`LightningDataModule.on_load_checkpoint` was deprecated in"
275-
" v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
276-
)

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,10 @@ def restore_datamodule(self) -> None:
243243
return
244244

245245
datamodule = self.trainer.datamodule
246-
if datamodule is not None:
247-
self.trainer._call_lightning_datamodule_hook("on_load_checkpoint", self._loaded_checkpoint)
248-
if datamodule.__class__.__qualname__ in self._loaded_checkpoint:
249-
self.trainer._call_lightning_datamodule_hook(
250-
"load_state_dict", self._loaded_checkpoint[datamodule.__class__.__qualname__]
251-
)
246+
if datamodule is not None and datamodule.__class__.__qualname__ in self._loaded_checkpoint:
247+
self.trainer._call_lightning_datamodule_hook(
248+
"load_state_dict", self._loaded_checkpoint[datamodule.__class__.__qualname__]
249+
)
252250

253251
def restore_model(self) -> None:
254252
"""Restores a model's weights from a PyTorch Lightning checkpoint.
@@ -519,9 +517,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
519517
# will be removed in v1.8
520518
self.trainer._call_callbacks_on_save_checkpoint(checkpoint)
521519
self.trainer._call_lightning_module_hook("on_save_checkpoint", checkpoint)
522-
if datamodule is not None:
523-
self.trainer._call_lightning_datamodule_hook("on_save_checkpoint", checkpoint)
524-
525520
return checkpoint
526521

527522
def save_checkpoint(

tests/tests_pytorch/core/test_datamodules.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,6 @@ def state_dict(self) -> Dict[str, Any]:
202202
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
203203
self.my_state_dict = state_dict
204204

205-
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
206-
checkpoint[self.__class__.__qualname__].update({"on_save": "update"})
207-
208-
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
209-
self.checkpoint_state = checkpoint.get(self.__class__.__qualname__).copy()
210-
checkpoint[self.__class__.__qualname__].pop("on_save")
211-
212205
reset_seed()
213206
dm = CustomBoringDataModule()
214207
model = CustomBoringModel()
@@ -223,21 +216,15 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
223216
)
224217

225218
# fit model
226-
with pytest.deprecated_call(
227-
match="`LightningDataModule.on_save_checkpoint` was deprecated in"
228-
" v1.6 and will be removed in v1.8. Use `state_dict` instead."
229-
):
230-
trainer.fit(model, datamodule=dm)
231-
assert trainer.state.finished, f"Training failed with {trainer.state}"
219+
trainer.fit(model, datamodule=dm)
232220
checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
233221
checkpoint = torch.load(checkpoint_path)
234222
assert dm.__class__.__qualname__ in checkpoint
235-
assert checkpoint[dm.__class__.__qualname__] == {"my": "state_dict", "on_save": "update"}
223+
assert checkpoint[dm.__class__.__qualname__] == {"my": "state_dict"}
236224

237225
for trainer_fn in TrainerFn:
238226
trainer.state.fn = trainer_fn
239227
trainer._restore_modules_and_callbacks(checkpoint_path)
240-
assert dm.checkpoint_state == {"my": "state_dict", "on_save": "update"}
241228
assert dm.my_state_dict == {"my": "state_dict"}
242229

243230

@@ -510,7 +497,6 @@ def state_dict(self):
510497
"[LightningDataModule]CustomBoringDataModule.prepare_data",
511498
"[LightningDataModule]CustomBoringDataModule.setup",
512499
"[LightningDataModule]CustomBoringDataModule.state_dict",
513-
"[LightningDataModule]CustomBoringDataModule.on_save_checkpoint",
514500
"[LightningDataModule]CustomBoringDataModule.teardown",
515501
]
516502
for key in keys:
@@ -527,7 +513,6 @@ def state_dict(self):
527513
keys = [
528514
"[LightningDataModule]CustomBoringDataModule.prepare_data",
529515
"[LightningDataModule]CustomBoringDataModule.setup",
530-
"[LightningDataModule]CustomBoringDataModule.on_load_checkpoint",
531516
"[LightningDataModule]CustomBoringDataModule.load_state_dict",
532517
"[LightningDataModule]CustomBoringDataModule.teardown",
533518
]

tests/tests_pytorch/deprecated_api/test_remove_1-8.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,12 @@
1313
# limitations under the License.
1414
"""Test deprecated functionality which will be removed in v1.8.0."""
1515
from unittest import mock
16-
from unittest.mock import Mock
1716

1817
import pytest
1918

2019
from pytorch_lightning import Callback, Trainer
2120
from pytorch_lightning.callbacks import ModelCheckpoint
22-
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
23-
from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks
21+
from pytorch_lightning.demos.boring_classes import BoringModel
2422

2523

2624
def test_v1_8_0_on_init_start_end(tmpdir):
@@ -230,32 +228,6 @@ def on_pretrain_routine_end(self, trainer, pl_module):
230228
trainer.fit(model)
231229

232230

233-
def test_v1_8_0_datamodule_checkpointhooks():
234-
class CustomBoringDataModuleSave(BoringDataModule):
235-
def on_save_checkpoint(self, checkpoint):
236-
print("override on_save_checkpoint")
237-
238-
class CustomBoringDataModuleLoad(BoringDataModule):
239-
def on_load_checkpoint(self, checkpoint):
240-
print("override on_load_checkpoint")
241-
242-
trainer = Mock()
243-
244-
trainer.datamodule = CustomBoringDataModuleSave()
245-
with pytest.deprecated_call(
246-
match="`LightningDataModule.on_save_checkpoint` was deprecated in"
247-
" v1.6 and will be removed in v1.8. Use `state_dict` instead."
248-
):
249-
_check_datamodule_checkpoint_hooks(trainer)
250-
251-
trainer.datamodule = CustomBoringDataModuleLoad()
252-
with pytest.deprecated_call(
253-
match="`LightningDataModule.on_load_checkpoint` was deprecated in"
254-
" v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
255-
):
256-
_check_datamodule_checkpoint_hooks(trainer)
257-
258-
259231
def test_deprecated_mc_save_checkpoint():
260232
mc = ModelCheckpoint()
261233
trainer = Trainer()

tests/tests_pytorch/models/test_hooks.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,6 @@ def test_trainer_datamodule_hook_system(tmpdir):
955955
dict(name="val_dataloader"),
956956
dict(name="train_dataloader"),
957957
dict(name="state_dict"),
958-
dict(name="on_save_checkpoint", args=(ANY,)),
959958
dict(name="teardown", kwargs=dict(stage="fit")),
960959
]
961960
assert called == expected
@@ -1022,13 +1021,10 @@ def state_dict(self):
10221021
}
10231022

10241023
assert lm_called == [dict(name="on_save_checkpoint", args=(saved_ckpt,))]
1025-
assert ldm_called == [dict(name="state_dict"), dict(name="on_save_checkpoint", args=(saved_ckpt,))]
1024+
assert ldm_called == [dict(name="state_dict")]
10261025

10271026
lm_called, ldm_called = [], []
1028-
model = HookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
1029-
datamodule = CustomHookedDataModule.load_from_checkpoint(ckpt_path, called=ldm_called)
1027+
_ = HookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
1028+
_ = CustomHookedDataModule.load_from_checkpoint(ckpt_path, called=ldm_called)
10301029
assert lm_called == [dict(name="on_load_checkpoint", args=({**saved_ckpt, "hyper_parameters": ANY},))]
1031-
assert ldm_called == [
1032-
dict(name="on_load_checkpoint", args=({**saved_ckpt, "datamodule_hyper_parameters": ANY},)),
1033-
dict(name="load_state_dict", args=(saved_ckpt[datamodule_state_dict_key],)),
1034-
]
1030+
assert ldm_called == [dict(name="load_state_dict", args=(saved_ckpt[datamodule_state_dict_key],))]

0 commit comments

Comments
 (0)