@@ -480,10 +480,14 @@ def __init__(
480
480
# default .predict() loop
481
481
self .predict_loop = PredictionLoop ()
482
482
483
- # .validate() and .test() set this when they load a checkpoint
484
- self .validated_ckpt_path : Optional [str ] = None
485
- self .tested_ckpt_path : Optional [str ] = None
486
- self .predicted_ckpt_path : Optional [str ] = None
483
+ # set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`.
484
+ self ._ckpt_path : Optional [str ] = None
485
+
486
+ # .validate(), predict() and .test() set these when they load a checkpoint. They will be removed in favor of
487
+ # the unified read-only `Trainer.ckpt_path` attribute in v1.8
488
+ self ._validated_ckpt_path : Optional [str ] = None # TODO: remove in v1.8
489
+ self ._tested_ckpt_path : Optional [str ] = None # TODO: remove in v1.8
490
+ self ._predicted_ckpt_path : Optional [str ] = None # TODO: remove in v1.8
487
491
488
492
# todo: remove in v1.7
489
493
self ._weights_summary : Optional [str ] = None
@@ -758,7 +762,10 @@ def _fit_impl(
758
762
759
763
# TODO: ckpt_path only in v2.0
760
764
ckpt_path = ckpt_path or self .resume_from_checkpoint
761
- results = self ._run (model , ckpt_path = ckpt_path )
765
+ self ._ckpt_path = self .__set_ckpt_path (
766
+ ckpt_path , model_provided = model , model_connected = self .lightning_module is not None
767
+ )
768
+ results = self ._run (model , ckpt_path = self .ckpt_path )
762
769
763
770
assert self .state .stopped
764
771
self .training = False
@@ -837,12 +844,14 @@ def _validate_impl(
837
844
# links data to the trainer
838
845
self ._data_connector .attach_data (model , val_dataloaders = dataloaders , datamodule = datamodule )
839
846
840
- self .validated_ckpt_path = self .__set_ckpt_path (
847
+ self ._ckpt_path = self .__set_ckpt_path (
841
848
ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
842
849
)
843
850
851
+ self ._validated_ckpt_path = self .ckpt_path # TODO: remove in v1.8
852
+
844
853
# run validate
845
- results = self ._run (model , ckpt_path = self .validated_ckpt_path )
854
+ results = self ._run (model , ckpt_path = self .ckpt_path )
846
855
847
856
assert self .state .stopped
848
857
self .validating = False
@@ -923,12 +932,14 @@ def _test_impl(
923
932
# links data to the trainer
924
933
self ._data_connector .attach_data (model , test_dataloaders = dataloaders , datamodule = datamodule )
925
934
926
- self .tested_ckpt_path = self .__set_ckpt_path (
935
+ self ._ckpt_path = self .__set_ckpt_path (
927
936
ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
928
937
)
929
938
939
+ self ._tested_ckpt_path = self .ckpt_path # TODO: remove in v1.8
940
+
930
941
# run test
931
- results = self ._run (model , ckpt_path = self .tested_ckpt_path )
942
+ results = self ._run (model , ckpt_path = self .ckpt_path )
932
943
933
944
assert self .state .stopped
934
945
self .testing = False
@@ -1009,11 +1020,13 @@ def _predict_impl(
1009
1020
# links data to the trainer
1010
1021
self ._data_connector .attach_data (model , predict_dataloaders = dataloaders , datamodule = datamodule )
1011
1022
1012
- self .predicted_ckpt_path = self .__set_ckpt_path (
1023
+ self ._ckpt_path = self .__set_ckpt_path (
1013
1024
ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
1014
1025
)
1015
1026
1016
- results = self ._run (model , ckpt_path = self .predicted_ckpt_path )
1027
+ self ._predicted_ckpt_path = self .ckpt_path # TODO: remove in v1.8
1028
+
1029
+ results = self ._run (model , ckpt_path = self .ckpt_path )
1017
1030
1018
1031
assert self .state .stopped
1019
1032
self .predicting = False
@@ -2219,6 +2232,74 @@ def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
2219
2232
2220
2233
return resume_from_checkpoint
2221
2234
2235
+ @property
2236
+ def ckpt_path (self ) -> Optional [str ]:
2237
+ """Set to the path/URL of a checkpoint loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`,
2238
+ :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`,
2239
+ :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`, or
2240
+ :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
2241
+ return self ._ckpt_path
2242
+
2243
+ @property
2244
+ def validated_ckpt_path (self ) -> Optional [str ]:
2245
+ rank_zero_deprecation (
2246
+ "The `Trainer.validated_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
2247
+ " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
2248
+ " `Trainer.ckpt_path` instead." ,
2249
+ stacklevel = 5 ,
2250
+ )
2251
+ return self ._validated_ckpt_path
2252
+
2253
+ @validated_ckpt_path .setter
2254
+ def validated_ckpt_path (self , ckpt_path : Optional [str ]) -> None :
2255
+ rank_zero_deprecation (
2256
+ "The `Trainer.validated_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
2257
+ " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
2258
+ " read-only `Trainer.ckpt_path`." ,
2259
+ stacklevel = 5 ,
2260
+ )
2261
+ self ._validated_ckpt_path = ckpt_path
2262
+
2263
+ @property
2264
+ def tested_ckpt_path (self ) -> Optional [str ]:
2265
+ rank_zero_deprecation (
2266
+ "The `Trainer.tested_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
2267
+ " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
2268
+ " `Trainer.ckpt_path` instead." ,
2269
+ stacklevel = 5 ,
2270
+ )
2271
+ return self ._tested_ckpt_path
2272
+
2273
+ @tested_ckpt_path .setter
2274
+ def tested_ckpt_path (self , ckpt_path : Optional [str ]) -> None :
2275
+ rank_zero_deprecation (
2276
+ "The `Trainer.tested_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
2277
+ " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
2278
+ " read-only `Trainer.ckpt_path` instead." ,
2279
+ stacklevel = 5 ,
2280
+ )
2281
+ self ._tested_ckpt_path = ckpt_path
2282
+
2283
+ @property
2284
+ def predicted_ckpt_path (self ) -> Optional [str ]:
2285
+ rank_zero_deprecation (
2286
+ "The `Trainer.predicted_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
2287
+ " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
2288
+ " `Trainer.ckpt_path` instead." ,
2289
+ stacklevel = 5 ,
2290
+ )
2291
+ return self ._predicted_ckpt_path
2292
+
2293
+ @predicted_ckpt_path .setter
2294
+ def predicted_ckpt_path (self , ckpt_path : Optional [str ]) -> None :
2295
+ rank_zero_deprecation (
2296
+ "The `Trainer.predicted_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
2297
+ " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
2298
+ " read-only `Trainer.ckpt_path` instead." ,
2299
+ stacklevel = 5 ,
2300
+ )
2301
+ self ._predicted_ckpt_path = ckpt_path
2302
+
2222
2303
def save_checkpoint (self , filepath : _PATH , weights_only : bool = False ) -> None :
2223
2304
r"""
2224
2305
Runs routine to create a checkpoint.
0 commit comments