
Commit 3bc2407

speediedan and carmocca authored
Allow access to ckpt_path within context of fit() (#11696)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 7da931d commit 3bc2407

File tree

6 files changed (+144, -51 lines)


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -258,6 +258,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Deprecated
 
+- Deprecated `Trainer.{validated,tested,predicted}_ckpt_path` and replaced with read-only property `Trainer.ckpt_path` set when checkpoints loaded via `Trainer.{fit,validate,test,predict}` ([#11696](https://github.com/PyTorchLightning/pytorch-lightning/pull/11696))
+
 - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/pull/10103))
 
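In user-facing terms, the deprecation above unifies three per-stage attributes into one read-only property. A minimal sketch of the before/after, assuming a hypothetical `LightningModule` subclass `MyModel` and a hypothetical checkpoint path:

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1)
model = MyModel()  # hypothetical LightningModule subclass

trainer.validate(model, ckpt_path="checkpoints/epoch=0.ckpt")  # hypothetical path

# new: one read-only property, set by fit/validate/test/predict
assert trainer.ckpt_path == "checkpoints/epoch=0.ckpt"

# old (deprecated in v1.6, removal planned for v1.8): emits a deprecation warning
_ = trainer.validated_ckpt_path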

pytorch_lightning/callbacks/finetuning.py

Lines changed: 4 additions & 4 deletions
@@ -99,7 +99,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
         if self._restarting:
             named_parameters = dict(pl_module.named_parameters())
             for opt_idx, optimizer in enumerate(trainer.optimizers):
-                param_groups = self.__apply_mapping_to_param_groups(
+                param_groups = self._apply_mapping_to_param_groups(
                     self._internal_optimizer_metadata[opt_idx], named_parameters
                 )
                 optimizer.param_groups = param_groups
@@ -245,7 +245,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
         self.freeze_before_training(pl_module)
 
     @staticmethod
-    def __apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]:
+    def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]:
         output = []
         for g in param_groups:
             # skip params to save memory
@@ -263,13 +263,13 @@ def _store(
     ) -> None:
         mapping = {p: n for n, p in pl_module.named_parameters()}
         if opt_idx not in self._internal_optimizer_metadata:
-            self._internal_optimizer_metadata[opt_idx] = self.__apply_mapping_to_param_groups(
+            self._internal_optimizer_metadata[opt_idx] = self._apply_mapping_to_param_groups(
                 current_param_groups, mapping
             )
         elif num_param_groups != len(current_param_groups):
             # save new param_groups possibly created by the users.
             self._internal_optimizer_metadata[opt_idx].extend(
-                self.__apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
+                self._apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
             )
 
     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
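A plausible reason for this otherwise mechanical rename: a double leading underscore makes Python mangle the name to `_BaseFinetuning__apply_mapping_to_param_groups`, so subclasses of the callback cannot call the helper under its written name. A standalone sketch of that behavior, with hypothetical class names:

class Base:
    @staticmethod
    def __mangled():  # stored on the class as `_Base__mangled`
        return "base"

    @staticmethod
    def _plain():  # single underscore: inherited under its own name
        return "base"


class Child(Base):
    def call(self):
        # `self.__mangled()` here would compile to `self._Child__mangled`
        # and raise AttributeError; the mangled name must be spelled out.
        return self._plain(), self._Base__mangled()


print(Child().call())  # ('base', 'base')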

pytorch_lightning/trainer/trainer.py

Lines changed: 92 additions & 11 deletions
@@ -480,10 +480,14 @@ def __init__(
         # default .predict() loop
         self.predict_loop = PredictionLoop()
 
-        # .validate() and .test() set this when they load a checkpoint
-        self.validated_ckpt_path: Optional[str] = None
-        self.tested_ckpt_path: Optional[str] = None
-        self.predicted_ckpt_path: Optional[str] = None
+        # set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`.
+        self._ckpt_path: Optional[str] = None
+
+        # .validate(), predict() and .test() set these when they load a checkpoint. They will be removed in favor of
+        # the unified read-only `Trainer.ckpt_path` attribute in v1.8
+        self._validated_ckpt_path: Optional[str] = None  # TODO: remove in v1.8
+        self._tested_ckpt_path: Optional[str] = None  # TODO: remove in v1.8
+        self._predicted_ckpt_path: Optional[str] = None  # TODO: remove in v1.8
 
         # todo: remove in v1.7
         self._weights_summary: Optional[str] = None
@@ -758,7 +762,10 @@ def _fit_impl(
 
         # TODO: ckpt_path only in v2.0
         ckpt_path = ckpt_path or self.resume_from_checkpoint
-        results = self._run(model, ckpt_path=ckpt_path)
+        self._ckpt_path = self.__set_ckpt_path(
+            ckpt_path, model_provided=model, model_connected=self.lightning_module is not None
+        )
+        results = self._run(model, ckpt_path=self.ckpt_path)
 
         assert self.state.stopped
         self.training = False
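This `_fit_impl` change is the commit's headline: `fit()` now resolves its checkpoint path through `__set_ckpt_path` before `_run`, so `trainer.ckpt_path` is already populated while `fit` is running. A hedged sketch of what that enables, using a hypothetical callback:

import pytorch_lightning as pl

class ReportResumeSource(pl.Callback):
    # hypothetical callback: report where the weights came from once fit starts
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.ckpt_path is not None:
            print(f"resuming from: {trainer.ckpt_path}")
        else:
            print("training from scratch")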
@@ -837,12 +844,14 @@ def _validate_impl(
         # links data to the trainer
         self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)
 
-        self.validated_ckpt_path = self.__set_ckpt_path(
+        self._ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
 
+        self._validated_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
+
         # run validate
-        results = self._run(model, ckpt_path=self.validated_ckpt_path)
+        results = self._run(model, ckpt_path=self.ckpt_path)
 
         assert self.state.stopped
         self.validating = False
@@ -923,12 +932,14 @@ def _test_impl(
         # links data to the trainer
         self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
 
-        self.tested_ckpt_path = self.__set_ckpt_path(
+        self._ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
 
+        self._tested_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
+
         # run test
-        results = self._run(model, ckpt_path=self.tested_ckpt_path)
+        results = self._run(model, ckpt_path=self.ckpt_path)
 
         assert self.state.stopped
         self.testing = False
@@ -1009,11 +1020,13 @@ def _predict_impl(
         # links data to the trainer
         self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
 
-        self.predicted_ckpt_path = self.__set_ckpt_path(
+        self._ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
 
-        results = self._run(model, ckpt_path=self.predicted_ckpt_path)
+        self._predicted_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
+
+        results = self._run(model, ckpt_path=self.ckpt_path)
 
         assert self.state.stopped
         self.predicting = False
@@ -2219,6 +2232,74 @@ def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
 
         return resume_from_checkpoint
 
+    @property
+    def ckpt_path(self) -> Optional[str]:
+        """Set to the path/URL of a checkpoint loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`,
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`,
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`, or
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
+        return self._ckpt_path
+
+    @property
+    def validated_ckpt_path(self) -> Optional[str]:
+        rank_zero_deprecation(
+            "The `Trainer.validated_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
+            " `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        return self._validated_ckpt_path
+
+    @validated_ckpt_path.setter
+    def validated_ckpt_path(self, ckpt_path: Optional[str]) -> None:
+        rank_zero_deprecation(
+            "The `Trainer.validated_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
+            " read-only `Trainer.ckpt_path`.",
+            stacklevel=5,
+        )
+        self._validated_ckpt_path = ckpt_path
+
+    @property
+    def tested_ckpt_path(self) -> Optional[str]:
+        rank_zero_deprecation(
+            "The `Trainer.tested_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
+            " `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        return self._tested_ckpt_path
+
+    @tested_ckpt_path.setter
+    def tested_ckpt_path(self, ckpt_path: Optional[str]) -> None:
+        rank_zero_deprecation(
+            "The `Trainer.tested_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
+            " read-only `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        self._tested_ckpt_path = ckpt_path
+
+    @property
+    def predicted_ckpt_path(self) -> Optional[str]:
+        rank_zero_deprecation(
+            "The `Trainer.predicted_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
+            " `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        return self._predicted_ckpt_path
+
+    @predicted_ckpt_path.setter
+    def predicted_ckpt_path(self, ckpt_path: Optional[str]) -> None:
+        rank_zero_deprecation(
+            "The `Trainer.predicted_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of a checkpoint loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
+            " read-only `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        self._predicted_ckpt_path = ckpt_path
+
     def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
         r"""
         Runs routine to create a checkpoint.
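The block above repeats one pattern three times: each old attribute survives as a property pair that warns on both read and write, while the new `ckpt_path` defines no setter and is therefore read-only. A stripped-down sketch of that pattern in plain Python, with `warnings.warn` standing in for Lightning's `rank_zero_deprecation` helper:

import warnings
from typing import Optional

class TrainerLike:  # hypothetical stand-in for Trainer
    def __init__(self) -> None:
        self._ckpt_path: Optional[str] = None
        self._validated_ckpt_path: Optional[str] = None

    @property
    def ckpt_path(self) -> Optional[str]:
        # no setter defined: assigning to `ckpt_path` raises AttributeError
        return self._ckpt_path

    @property
    def validated_ckpt_path(self) -> Optional[str]:
        warnings.warn("use `ckpt_path` instead", DeprecationWarning, stacklevel=2)
        return self._validated_ckpt_path

    @validated_ckpt_path.setter
    def validated_ckpt_path(self, value: Optional[str]) -> None:
        warnings.warn("use the read-only `ckpt_path` instead", DeprecationWarning, stacklevel=2)
        self._validated_ckpt_path = value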

tests/callbacks/test_finetuning_callback.py

Lines changed: 26 additions & 24 deletions
@@ -287,34 +287,36 @@ def configure_optimizers(self):
     trainer.fit(model)
 
 
-def test_complex_nested_model():
-    """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
-    directly themselves rather than exclusively their submodules containing parameters."""
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, 3)
+        self.act = nn.ReLU()
+        self.bn = nn.BatchNorm2d(out_channels)
 
-    class ConvBlock(nn.Module):
-        def __init__(self, in_channels, out_channels):
-            super().__init__()
-            self.conv = nn.Conv2d(in_channels, out_channels, 3)
-            self.act = nn.ReLU()
-            self.bn = nn.BatchNorm2d(out_channels)
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.act(x)
+        return self.bn(x)
 
-        def forward(self, x):
-            x = self.conv(x)
-            x = self.act(x)
-            return self.bn(x)
 
-    class ConvBlockParam(nn.Module):
-        def __init__(self, in_channels, out_channels):
-            super().__init__()
-            self.module_dict = nn.ModuleDict({"conv": nn.Conv2d(in_channels, out_channels, 3), "act": nn.ReLU()})
-            # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
-            self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
-            self.bn = nn.BatchNorm2d(out_channels)
+class ConvBlockParam(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.module_dict = nn.ModuleDict({"conv": nn.Conv2d(in_channels, out_channels, 3), "act": nn.ReLU()})
+        # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
+        self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
+        self.bn = nn.BatchNorm2d(out_channels)
 
-        def forward(self, x):
-            x = self.module_dict["conv"](x)
-            x = self.module_dict["act"](x)
-            return self.bn(x)
+    def forward(self, x):
+        x = self.module_dict["conv"](x)
+        x = self.module_dict["act"](x)
+        return self.bn(x)
+
+
+def test_complex_nested_model():
+    """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
+    directly themselves rather than exclusively their submodules containing parameters."""
 
     model = nn.Sequential(
         OrderedDict(

tests/deprecated_api/test_remove_1-8.py

Lines changed: 10 additions & 0 deletions
@@ -146,6 +146,16 @@ def test_v1_8_0_trainer_verbose_evaluate():
     trainer.verbose_evaluate = False
 
 
+@pytest.mark.parametrize("fn_prefix", ["validated", "tested", "predicted"])
+def test_v1_8_0_trainer_ckpt_path_attributes(fn_prefix: str):
+    test_attr = f"{fn_prefix}_ckpt_path"
+    trainer = Trainer()
+    with pytest.deprecated_call(match=f"{test_attr}` attribute was deprecated in v1.6 and will be removed in v1.8"):
+        _ = getattr(trainer, test_attr)
+    with pytest.deprecated_call(match=f"{test_attr}` attribute was deprecated in v1.6 and will be removed in v1.8"):
+        setattr(trainer, test_attr, "v")
+
+
 def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir):
     trainer = Trainer()
     with pytest.deprecated_call(
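For readers unfamiliar with the idiom: `pytest.deprecated_call()` is a context manager that fails the test unless its body emits a `DeprecationWarning` (or `PendingDeprecationWarning`) whose message matches the optional `match` pattern. A self-contained illustration:

import warnings
import pytest

def old_api() -> None:
    warnings.warn("`old_api` was deprecated in v1.6", DeprecationWarning)

def test_old_api_warns() -> None:
    with pytest.deprecated_call(match="deprecated in v1.6"):
        old_api()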

tests/trainer/test_trainer.py

Lines changed: 10 additions & 12 deletions
@@ -686,8 +686,7 @@ def predict_step(self, batch, *_):
     trainer.fit(model)
 
     trainer_fn = getattr(trainer, fn)
-    path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
-    assert getattr(trainer, path_attr) is None
+    assert getattr(trainer, "ckpt_path") is None
 
     if ckpt_path == "best":
         # ckpt_path is 'best', meaning we load the best weights
@@ -698,20 +697,20 @@ def predict_step(self, batch, *_):
             trainer_fn(model, ckpt_path=ckpt_path)
         else:
             trainer_fn(ckpt_path=ckpt_path)
-            assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+            assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
 
             trainer_fn(model, ckpt_path=ckpt_path)
-            assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+            assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
     elif ckpt_path is None:
         # ckpt_path is None, meaning we don't load any checkpoints and use the provided model
         trainer_fn(model, ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) is None
+        assert getattr(trainer, "ckpt_path") is None
 
         if save_top_k > 0:
             # ckpt_path is None with no model provided means load the best weights
             with pytest.warns(UserWarning, match="The best model of the previous `fit` call will be used"):
                 trainer_fn(ckpt_path=ckpt_path)
-                assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+                assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
     else:
         # specific checkpoint, pick one from saved ones
         if save_top_k == 0:
@@ -724,10 +723,10 @@ def predict_step(self, batch, *_):
             ].absolute()
         )
         trainer_fn(ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) == ckpt_path
+        assert getattr(trainer, "ckpt_path") == ckpt_path
 
         trainer_fn(model, ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) == ckpt_path
+        assert getattr(trainer, "ckpt_path") == ckpt_path
 
 
 @pytest.mark.parametrize("enable_checkpointing", (False, True))
@@ -758,15 +757,14 @@ def predict_step(self, batch, *_):
     trainer.fit(model)
 
     trainer_fn = getattr(trainer, fn)
-    path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
-    assert getattr(trainer, path_attr) is None
+    assert getattr(trainer, "ckpt_path") is None
 
     if enable_checkpointing:
         trainer_fn(ckpt_path="best")
-        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+        assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
 
         trainer_fn(model, ckpt_path="best")
-        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+        assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
     else:
         with pytest.raises(MisconfigurationException, match="`ModelCheckpoint` is not configured."):
             trainer_fn(ckpt_path="best")
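Condensing the `ckpt_path` semantics these tests exercise into one hedged sketch (assumes the test suite's `BoringModel` or any small `LightningModule`, plus the default `ModelCheckpoint` callback):

trainer = Trainer(max_epochs=1)
model = BoringModel()  # any minimal LightningModule works here
trainer.fit(model)

trainer.validate(model, ckpt_path=None)  # use in-memory weights; nothing is loaded
assert trainer.ckpt_path is None

trainer.validate(ckpt_path="best")  # load the best checkpoint tracked by ModelCheckpoint
assert trainer.ckpt_path == trainer.checkpoint_callback.best_model_path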
