Skip to content

Commit 1b4212a

Browse files
awaelchliBorda
authored andcommitted
Error messages for removed DataModule hooks (#15072)
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 8704cc8 commit 1b4212a

File tree

5 files changed

+131
-20
lines changed

5 files changed

+131
-20
lines changed

src/pytorch_lightning/_graveyard/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import pytorch_lightning._graveyard.callbacks
16+
import pytorch_lightning._graveyard.core
1617
import pytorch_lightning._graveyard.loggers
1718
import pytorch_lightning._graveyard.trainer
1819
import pytorch_lightning._graveyard.training_type # noqa: F401
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Any
15+
16+
from pytorch_lightning import LightningDataModule
17+
18+
19+
def _on_save_checkpoint(_: LightningDataModule, __: Any) -> None:
20+
# TODO: Remove in v2.0.0
21+
raise NotImplementedError(
22+
"`LightningDataModule.on_save_checkpoint` was deprecated in v1.6 and removed in v1.8."
23+
" Use `state_dict` instead."
24+
)
25+
26+
27+
def _on_load_checkpoint(_: LightningDataModule, __: Any) -> None:
28+
# TODO: Remove in v2.0.0
29+
raise NotImplementedError(
30+
"`LightningDataModule.on_load_checkpoint` was deprecated in v1.6 and removed in v1.8."
31+
" Use `load_state_dict` instead."
32+
)
33+
34+
35+
# Methods
36+
LightningDataModule.on_save_checkpoint = _on_save_checkpoint
37+
LightningDataModule.on_load_checkpoint = _on_load_checkpoint

src/pytorch_lightning/_graveyard/trainer.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,154 +40,154 @@ def __init__(self) -> None:
4040

4141

4242
def _gpus(_: Trainer) -> None:
43-
# Remove in v2.0.0
43+
# TODO: Remove in v2.0.0
4444
raise AttributeError(
4545
"`Trainer.gpus` was deprecated in v1.6 and is no longer accessible as of v1.8."
4646
" Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
4747
)
4848

4949

5050
def _root_gpu(_: Trainer) -> None:
51-
# Remove in v2.0.0
51+
# TODO: Remove in v2.0.0
5252
raise AttributeError(
5353
"`Trainer.root_gpu` was deprecated in v1.6 and is no longer accessible as of v1.8."
5454
" Please use `Trainer.strategy.root_device.index` instead."
5555
)
5656

5757

5858
def _tpu_cores(_: Trainer) -> None:
59-
# Remove in v2.0.0
59+
# TODO: Remove in v2.0.0
6060
raise AttributeError(
6161
"`Trainer.tpu_cores` was deprecated in v1.6 and is no longer accessible as of v1.8."
6262
" Please use `Trainer.num_devices` instead."
6363
)
6464

6565

6666
def _ipus(_: Trainer) -> None:
67-
# Remove in v2.0.0
67+
# TODO: Remove in v2.0.0
6868
raise AttributeError(
6969
"`Trainer.ipus` was deprecated in v1.6 and is no longer accessible as of v1.8."
7070
" Please use `Trainer.num_devices` instead."
7171
)
7272

7373

7474
def _num_gpus(_: Trainer) -> None:
75-
# Remove in v2.0.0
75+
# TODO: Remove in v2.0.0
7676
raise AttributeError(
7777
"`Trainer.num_gpus` was deprecated in v1.6 and is no longer accessible as of v1.8."
7878
" Please use `Trainer.num_devices` instead."
7979
)
8080

8181

8282
def _devices(_: Trainer) -> None:
83-
# Remove in v2.0.0
83+
# TODO: Remove in v2.0.0
8484
raise AttributeError(
8585
"`Trainer.devices` was deprecated in v1.6 and is no longer accessible as of v1.8."
8686
" Please use `Trainer.num_devices` or `Trainer.device_ids` to get device information instead."
8787
)
8888

8989

9090
def _use_amp(_: Trainer) -> None:
91-
# Remove in v2.0.0
91+
# TODO: Remove in v2.0.0
9292
raise AttributeError(
9393
"`Trainer.use_amp` was deprecated in v1.6 and is no longer accessible as of v1.8."
9494
" Please use `Trainer.amp_backend`.",
9595
)
9696

9797

9898
def _weights_save_path(_: Trainer) -> None:
99-
# Remove in v2.0.0
99+
# TODO: Remove in v2.0.0
100100
raise AttributeError("`Trainer.weights_save_path` was deprecated in v1.6 and is no longer accessible as of v1.8.")
101101

102102

103103
def _lightning_optimizers(_: Trainer) -> None:
104-
# Remove in v2.0.0
104+
# TODO: Remove in v2.0.0
105105
raise AttributeError(
106106
"`Trainer.lightning_optimizers` was deprecated in v1.6 and is no longer accessible as of v1.8."
107107
)
108108

109109

110110
def _should_rank_save_checkpoint(_: Trainer) -> None:
111-
# Remove in v2.0.0
111+
# TODO: Remove in v2.0.0
112112
raise AttributeError(
113113
"`Trainer.should_rank_save_checkpoint` was deprecated in v1.6 and is no longer accessible as of v1.8.",
114114
)
115115

116116

117117
def _validated_ckpt_path(_: Trainer) -> None:
118-
# Remove in v2.0.0
118+
# TODO: Remove in v2.0.0
119119
raise AttributeError(
120120
"The `Trainer.validated_ckpt_path` was deprecated in v1.6 and is no longer accessible as of v1.8."
121121
" Please use `Trainer.ckpt_path` instead."
122122
)
123123

124124

125125
def _validated_ckpt_path_setter(_: Trainer, __: Optional[str]) -> None:
126-
# Remove in v2.0.0
126+
# TODO: Remove in v2.0.0
127127
raise AttributeError(
128128
"The `Trainer.validated_ckpt_path` was deprecated in v1.6 and is no longer accessible as of v1.8."
129129
" Please use `Trainer.ckpt_path` instead."
130130
)
131131

132132

133133
def _tested_ckpt_path(_: Trainer) -> None:
134-
# Remove in v2.0.0
134+
# TODO: Remove in v2.0.0
135135
raise AttributeError(
136136
"The `Trainer.tested_ckpt_path` was deprecated in v1.6 and is no longer accessible as of v1.8."
137137
" Please use `Trainer.ckpt_path` instead."
138138
)
139139

140140

141141
def _tested_ckpt_path_setter(_: Trainer, __: Optional[str]) -> None:
142-
# Remove in v2.0.0
142+
# TODO: Remove in v2.0.0
143143
raise AttributeError(
144144
"The `Trainer.tested_ckpt_path` was deprecated in v1.6 and is no longer accessible as of v1.8."
145145
" Please use `Trainer.ckpt_path` instead."
146146
)
147147

148148

149149
def _predicted_ckpt_path(_: Trainer) -> None:
150-
# Remove in v2.0.0
150+
# TODO: Remove in v2.0.0
151151
raise AttributeError(
152152
"The `Trainer.predicted_ckpt_path` was deprecated in v1.6 and is no longer accessible as of v1.8."
153153
" Please use `Trainer.ckpt_path` instead."
154154
)
155155

156156

157157
def _predicted_ckpt_path_setter(_: Trainer, __: Optional[str]) -> None:
158-
# Remove in v2.0.0
158+
# TODO: Remove in v2.0.0
159159
raise AttributeError(
160160
"The `Trainer.predicted_ckpt_path` was deprecated in v1.6 and is no longer accessible as of v1.8."
161161
" Please use `Trainer.ckpt_path` instead."
162162
)
163163

164164

165165
def _verbose_evaluate(_: Trainer) -> None:
166-
# Remove in v2.0.0
166+
# TODO: Remove in v2.0.0
167167
raise AttributeError(
168168
"The `Trainer.verbose_evaluate` was deprecated in v1.6 and is no longer accessible as of v1.8."
169169
" Please use `trainer.{validate,test}_loop.verbose` instead.",
170170
)
171171

172172

173173
def _verbose_evaluate_setter(_: Trainer, __: bool) -> None:
174-
# Remove in v2.0.0
174+
# TODO: Remove in v2.0.0
175175
raise AttributeError(
176176
"The `Trainer.verbose_evaluate` was deprecated in v1.6 and is no longer accessible as of v1.8."
177177
" Please use `trainer.{validate,test}_loop.verbose` instead.",
178178
)
179179

180180

181181
def _run_stage(_: Trainer) -> None:
182-
# Remove in v2.0.0
182+
# TODO: Remove in v2.0.0
183183
raise NotImplementedError(
184184
"`Trainer.run_stage` was deprecated in v1.6 and is no longer supported as of v1.8."
185185
" Please use `Trainer.{fit,validate,test,predict}` instead."
186186
)
187187

188188

189189
def _call_hook(_: Trainer, *__: Any, **___: Any) -> Any:
190-
# Remove in v2.0.0
190+
# TODO: Remove in v2.0.0
191191
raise NotImplementedError("`Trainer.call_hook` was deprecated in v1.6 and is no longer supported as of v1.8.")
192192

193193

src/pytorch_lightning/trainer/configuration_validator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
5757
_check_on_pretrain_routine(model)
5858
# TODO: Delete this check in v2.0
5959
_check_deprecated_logger_methods(trainer)
60+
# TODO: Delete this check in v2.0
61+
_check_unsupported_datamodule_hooks(trainer)
6062

6163

6264
def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@@ -278,3 +280,19 @@ def _check_deprecated_logger_methods(trainer: "pl.Trainer") -> None:
278280
f"`{type(logger).__name__}.agg_and_log_metrics` was deprecated in v1.6 and is no longer supported as of"
279281
" v1.8."
280282
)
283+
284+
285+
def _check_unsupported_datamodule_hooks(trainer: "pl.Trainer") -> None:
286+
datahook_selector = trainer._data_connector._datahook_selector
287+
assert datahook_selector is not None
288+
289+
if is_overridden("on_save_checkpoint", datahook_selector.datamodule):
290+
raise NotImplementedError(
291+
"`LightningDataModule.on_save_checkpoint` was deprecated in v1.6 and removed in v1.8."
292+
" Use `state_dict` instead."
293+
)
294+
if is_overridden("on_load_checkpoint", datahook_selector.datamodule):
295+
raise NotImplementedError(
296+
"`LightningDataModule.on_load_checkpoint` was deprecated in v1.6 and removed in v1.8."
297+
" Use `load_state_dict` instead."
298+
)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytest
15+
16+
from pytorch_lightning import Trainer
17+
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
18+
19+
20+
def test_v2_0_0_unsupported_datamodule_on_save_load_checkpoint():
21+
datamodule = BoringDataModule()
22+
with pytest.raises(
23+
NotImplementedError,
24+
match="`LightningDataModule.on_save_checkpoint` was deprecated in v1.6 and removed in v1.8.",
25+
):
26+
datamodule.on_save_checkpoint({})
27+
28+
with pytest.raises(
29+
NotImplementedError,
30+
match="`LightningDataModule.on_load_checkpoint` was deprecated in v1.6 and removed in v1.8.",
31+
):
32+
datamodule.on_load_checkpoint({})
33+
34+
class OnSaveDataModule(BoringDataModule):
35+
def on_save_checkpoint(self, checkpoint):
36+
pass
37+
38+
class OnLoadDataModule(BoringDataModule):
39+
def on_load_checkpoint(self, checkpoint):
40+
pass
41+
42+
trainer = Trainer()
43+
model = BoringModel()
44+
45+
with pytest.raises(
46+
NotImplementedError,
47+
match="`LightningDataModule.on_save_checkpoint` was deprecated in v1.6 and removed in v1.8.",
48+
):
49+
trainer.fit(model, OnSaveDataModule())
50+
51+
with pytest.raises(
52+
NotImplementedError,
53+
match="`LightningDataModule.on_load_checkpoint` was deprecated in v1.6 and removed in v1.8.",
54+
):
55+
trainer.fit(model, OnLoadDataModule())

0 commit comments

Comments
 (0)