Skip to content

Commit 8105b09

Browse files
committed
create read-only property set when checkpoints loaded via Trainer.{fit,validate,test,predict}, deprecate Trainer.{validated,tested,predicted}_ckpt_path
1 parent ec1379d commit 8105b09

File tree

3 files changed

+112
-23
lines changed

3 files changed

+112
-23
lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 92 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -480,10 +480,14 @@ def __init__(
480480
# default .predict() loop
481481
self.predict_loop = PredictionLoop()
482482

483-
# .validate() and .test() set this when they load a checkpoint
484-
self.validated_ckpt_path: Optional[str] = None
485-
self.tested_ckpt_path: Optional[str] = None
486-
self.predicted_ckpt_path: Optional[str] = None
483+
# set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`.
484+
self._ckpt_path: Optional[str] = None
485+
486+
# .validate(), predict() and .test() set these when they load a checkpoint. They will be removed in favor of
487+
# the unified read-only `Trainer.ckpt_path` attribute in v1.8
488+
self._validated_ckpt_path: Optional[str] = None # TODO: remove in v1.8
489+
self._tested_ckpt_path: Optional[str] = None # TODO: remove in v1.8
490+
self._predicted_ckpt_path: Optional[str] = None # TODO: remove in v1.8
487491

488492
# todo: remove in v1.7
489493
self._weights_summary: Optional[str] = None
@@ -758,7 +762,10 @@ def _fit_impl(
758762

759763
# TODO: ckpt_path only in v2.0
760764
ckpt_path = ckpt_path or self.resume_from_checkpoint
761-
results = self._run(model, ckpt_path=ckpt_path)
765+
self._ckpt_path = self.__set_ckpt_path(
766+
ckpt_path, model_provided=model, model_connected=self.lightning_module is not None
767+
)
768+
results = self._run(model, ckpt_path=self.ckpt_path)
762769

763770
assert self.state.stopped
764771
self.training = False
@@ -837,12 +844,14 @@ def _validate_impl(
837844
# links data to the trainer
838845
self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)
839846

840-
self.validated_ckpt_path = self.__set_ckpt_path(
847+
self._ckpt_path = self.__set_ckpt_path(
841848
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
842849
)
843850

851+
self._validated_ckpt_path = self.ckpt_path # TODO: remove in v1.8
852+
844853
# run validate
845-
results = self._run(model, ckpt_path=self.validated_ckpt_path)
854+
results = self._run(model, ckpt_path=self.ckpt_path)
846855

847856
assert self.state.stopped
848857
self.validating = False
@@ -923,12 +932,14 @@ def _test_impl(
923932
# links data to the trainer
924933
self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
925934

926-
self.tested_ckpt_path = self.__set_ckpt_path(
935+
self._ckpt_path = self.__set_ckpt_path(
927936
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
928937
)
929938

939+
self._tested_ckpt_path = self.ckpt_path # TODO: remove in v1.8
940+
930941
# run test
931-
results = self._run(model, ckpt_path=self.tested_ckpt_path)
942+
results = self._run(model, ckpt_path=self.ckpt_path)
932943

933944
assert self.state.stopped
934945
self.testing = False
@@ -1009,11 +1020,13 @@ def _predict_impl(
10091020
# links data to the trainer
10101021
self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
10111022

1012-
self.predicted_ckpt_path = self.__set_ckpt_path(
1023+
self._ckpt_path = self.__set_ckpt_path(
10131024
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
10141025
)
10151026

1016-
results = self._run(model, ckpt_path=self.predicted_ckpt_path)
1027+
self._predicted_ckpt_path = self.ckpt_path # TODO: remove in v1.8
1028+
1029+
results = self._run(model, ckpt_path=self.ckpt_path)
10171030

10181031
assert self.state.stopped
10191032
self.predicting = False
@@ -2217,6 +2230,74 @@ def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
22172230

22182231
return resume_from_checkpoint
22192232

2233+
@property
2234+
def ckpt_path(self) -> Optional[str]:
2235+
"""Set to the path/URL of checkpoints loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`,
2236+
:meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`,
2237+
:meth:`~pytorch_lightning.trainer.trainer.Trainer.test`, or
2238+
:meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
2239+
return self._ckpt_path
2240+
2241+
@property
2242+
def validated_ckpt_path(self) -> Optional[str]:
2243+
rank_zero_deprecation(
2244+
"The `Trainer.validated_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
2245+
" path of checkpoints loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
2246+
" `Trainer.ckpt_path` instead.",
2247+
stacklevel=5,
2248+
)
2249+
return self._validated_ckpt_path
2250+
2251+
@validated_ckpt_path.setter
2252+
def validated_ckpt_path(self, ckpt_path: Optional[str]) -> None:
2253+
rank_zero_deprecation(
2254+
"The `Trainer.validated_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
2255+
" path of checkpoints loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the read-only"
2256+
" `Trainer.ckpt_path`.",
2257+
stacklevel=5,
2258+
)
2259+
self._validated_ckpt_path = ckpt_path
2260+
2261+
@property
2262+
def tested_ckpt_path(self) -> Optional[str]:
2263+
rank_zero_deprecation(
2264+
"The `Trainer.tested_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
2265+
" path of checkpoints loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
2266+
" `Trainer.ckpt_path` instead.",
2267+
stacklevel=5,
2268+
)
2269+
return self._tested_ckpt_path
2270+
2271+
@tested_ckpt_path.setter
2272+
def tested_ckpt_path(self, ckpt_path: Optional[str]) -> None:
2273+
rank_zero_deprecation(
2274+
"The `Trainer.tested_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
2275+
" path of checkpoints loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the read-only"
2276+
" `Trainer.ckpt_path` instead.",
2277+
stacklevel=5,
2278+
)
2279+
self._tested_ckpt_path = ckpt_path
2280+
2281+
@property
2282+
def predicted_ckpt_path(self) -> Optional[str]:
2283+
rank_zero_deprecation(
2284+
"The `Trainer.predicted_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
2285+
" path of checkpoints loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
2286+
" `Trainer.ckpt_path` instead.",
2287+
stacklevel=5,
2288+
)
2289+
return self._predicted_ckpt_path
2290+
2291+
@predicted_ckpt_path.setter
2292+
def predicted_ckpt_path(self, ckpt_path: Optional[str]) -> None:
2293+
rank_zero_deprecation(
2294+
"The `Trainer.predicted_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
2295+
" path of checkpoints loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the read-only"
2296+
" `Trainer.ckpt_path` instead.",
2297+
stacklevel=5,
2298+
)
2299+
self._predicted_ckpt_path = ckpt_path
2300+
22202301
def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
22212302
r"""
22222303
Runs routine to create a checkpoint.

tests/deprecated_api/test_remove_1-8.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,16 @@ def test_v1_8_0_trainer_verbose_evaluate():
146146
trainer.verbose_evaluate = False
147147

148148

149+
@pytest.mark.parametrize("fn_prefix", ["validated", "tested", "predicted"], ids=["validated", "tested", "predicted"])
150+
def test_v1_8_0_trainer_ckpt_path_attributes(fn_prefix: str):
151+
test_attr = f"{fn_prefix}_ckpt_path"
152+
trainer = Trainer()
153+
with pytest.deprecated_call(match=f"{test_attr}` attribute was deprecated in v1.6 and will be removed in v1.8"):
154+
_ = getattr(trainer, test_attr)
155+
with pytest.deprecated_call(match=f"{test_attr}` attribute was deprecated in v1.6 and will be removed in v1.8"):
156+
setattr(trainer, test_attr, "v")
157+
158+
149159
def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir):
150160
trainer = Trainer()
151161
with pytest.deprecated_call(

tests/trainer/test_trainer.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -686,8 +686,7 @@ def predict_step(self, batch, *_):
686686
trainer.fit(model)
687687

688688
trainer_fn = getattr(trainer, fn)
689-
path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
690-
assert getattr(trainer, path_attr) is None
689+
assert getattr(trainer, "ckpt_path") is None
691690

692691
if ckpt_path == "best":
693692
# ckpt_path is 'best', meaning we load the best weights
@@ -698,20 +697,20 @@ def predict_step(self, batch, *_):
698697
trainer_fn(model, ckpt_path=ckpt_path)
699698
else:
700699
trainer_fn(ckpt_path=ckpt_path)
701-
assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
700+
assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
702701

703702
trainer_fn(model, ckpt_path=ckpt_path)
704-
assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
703+
assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
705704
elif ckpt_path is None:
706705
# ckpt_path is None, meaning we don't load any checkpoints and use the provided model
707706
trainer_fn(model, ckpt_path=ckpt_path)
708-
assert getattr(trainer, path_attr) is None
707+
assert getattr(trainer, "ckpt_path") is None
709708

710709
if save_top_k > 0:
711710
# ckpt_path is None with no model provided means load the best weights
712711
with pytest.warns(UserWarning, match="The best model of the previous `fit` call will be used"):
713712
trainer_fn(ckpt_path=ckpt_path)
714-
assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
713+
assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
715714
else:
716715
# specific checkpoint, pick one from saved ones
717716
if save_top_k == 0:
@@ -724,10 +723,10 @@ def predict_step(self, batch, *_):
724723
].absolute()
725724
)
726725
trainer_fn(ckpt_path=ckpt_path)
727-
assert getattr(trainer, path_attr) == ckpt_path
726+
assert getattr(trainer, "ckpt_path") == ckpt_path
728727

729728
trainer_fn(model, ckpt_path=ckpt_path)
730-
assert getattr(trainer, path_attr) == ckpt_path
729+
assert getattr(trainer, "ckpt_path") == ckpt_path
731730

732731

733732
@pytest.mark.parametrize("enable_checkpointing", (False, True))
@@ -758,15 +757,14 @@ def predict_step(self, batch, *_):
758757
trainer.fit(model)
759758

760759
trainer_fn = getattr(trainer, fn)
761-
path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
762-
assert getattr(trainer, path_attr) is None
760+
assert getattr(trainer, "ckpt_path") is None
763761

764762
if enable_checkpointing:
765763
trainer_fn(ckpt_path="best")
766-
assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
764+
assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
767765

768766
trainer_fn(model, ckpt_path="best")
769-
assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
767+
assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
770768
else:
771769
with pytest.raises(MisconfigurationException, match="`ModelCheckpoint` is not configured."):
772770
trainer_fn(ckpt_path="best")

0 commit comments

Comments
 (0)