Allow access to ckpt_path within context of fit() #11696

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -77,6 +77,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Deprecated `Trainer.{validated,tested,predicted}_ckpt_path` and replaced them with the read-only property `Trainer.ckpt_path`, which is set when checkpoints are loaded via `Trainer.{fit,validate,test,predict}` ([#11696](https://github.com/PyTorchLightning/pytorch-lightning/pull/11696))
+
+
 - Set the `prog_bar` flag to False in `LightningModule.log_grad_norm` ([#11472](https://github.com/PyTorchLightning/pytorch-lightning/pull/11472))
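As a quick migration sketch for this entry (a fragment, not code from the diff; the `model` object is illustrative, and the same pattern applies to `fit`, `test`, and `predict`):

```python
trainer.validate(model, ckpt_path="best")

# Before (deprecated in v1.6, to be removed in v1.8):
path = trainer.validated_ckpt_path

# After: one read-only property shared by fit/validate/test/predict:
path = trainer.ckpt_path
```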
8 changes: 4 additions & 4 deletions pytorch_lightning/callbacks/finetuning.py
@@ -98,7 +98,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
         if self._restarting:
             named_parameters = dict(pl_module.named_parameters())
             for opt_idx, optimizer in enumerate(trainer.optimizers):
-                param_groups = self.__apply_mapping_to_param_groups(
+                param_groups = self._apply_mapping_to_param_groups(
                     self._internal_optimizer_metadata[opt_idx], named_parameters
                 )
                 optimizer.param_groups = param_groups
@@ -244,7 +244,7 @@ def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module:
         self.freeze_before_training(pl_module)
 
     @staticmethod
-    def __apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]:
+    def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]:
         output = []
         for g in param_groups:
             # skip params to save memory
@@ -262,13 +262,13 @@ def _store(
     ) -> None:
         mapping = {p: n for n, p in pl_module.named_parameters()}
         if opt_idx not in self._internal_optimizer_metadata:
-            self._internal_optimizer_metadata[opt_idx] = self.__apply_mapping_to_param_groups(
+            self._internal_optimizer_metadata[opt_idx] = self._apply_mapping_to_param_groups(
                 current_param_groups, mapping
             )
         elif num_param_groups != len(current_param_groups):
             # save new param_groups possibly created by the users.
             self._internal_optimizer_metadata[opt_idx].extend(
-                self.__apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
+                self._apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
             )
 
     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
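The rename from `__apply_mapping_to_param_groups` to `_apply_mapping_to_param_groups` drops Python's name mangling. The diff does not state a motivation, but a plausible one is that mangled names are awkward to reference from subclasses and tests. A minimal illustration of the mechanics (not code from the PR):

```python
class BaseFinetuning:
    @staticmethod
    def __mangled():  # stored on the class as `_BaseFinetuning__mangled`
        return "hard to reach"

    @staticmethod
    def _plain():  # reachable as `BaseFinetuning._plain` from tests and subclasses
        return "easy to reach"


print("_BaseFinetuning__mangled" in vars(BaseFinetuning))  # True: the name was mangled
print(BaseFinetuning._plain())  # "easy to reach"
```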
103 changes: 92 additions & 11 deletions pytorch_lightning/trainer/trainer.py
@@ -480,10 +480,14 @@ def __init__(
         # default .predict() loop
         self.predict_loop = PredictionLoop()
 
-        # .validate() and .test() set this when they load a checkpoint
-        self.validated_ckpt_path: Optional[str] = None
-        self.tested_ckpt_path: Optional[str] = None
-        self.predicted_ckpt_path: Optional[str] = None
+        # set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`.
+        self._ckpt_path: Optional[str] = None
+
+        # .validate(), .predict() and .test() set these when they load a checkpoint. They will be removed in favor of
+        # the unified read-only `Trainer.ckpt_path` attribute in v1.8
+        self._validated_ckpt_path: Optional[str] = None  # TODO: remove in v1.8
+        self._tested_ckpt_path: Optional[str] = None  # TODO: remove in v1.8
+        self._predicted_ckpt_path: Optional[str] = None  # TODO: remove in v1.8
 
         # todo: remove in v1.7
         self._weights_summary: Optional[str] = None
@@ -758,7 +762,10 @@ def _fit_impl(
 
         # TODO: ckpt_path only in v2.0
         ckpt_path = ckpt_path or self.resume_from_checkpoint
-        results = self._run(model, ckpt_path=ckpt_path)
+        self._ckpt_path = self.__set_ckpt_path(
+            ckpt_path, model_provided=model, model_connected=self.lightning_module is not None
+        )
+        results = self._run(model, ckpt_path=self.ckpt_path)
 
         assert self.state.stopped
         self.training = False
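This hunk is the heart of the PR title: `fit` now resolves its `ckpt_path` through the same `__set_ckpt_path` helper as the other entry points, so the resolved path is visible as `trainer.ckpt_path` during and after `fit` (e.g. from callbacks). A self-contained usage sketch under this PR's behavior; `TinyModel` and the toy data are illustrative:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


data = DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=4)

trainer = pl.Trainer(max_epochs=1)
trainer.fit(TinyModel(), data)
assert trainer.ckpt_path is None  # no checkpoint was loaded by this fit()

# resume from the checkpoint written by the first run
trainer2 = pl.Trainer(max_epochs=2)
trainer2.fit(TinyModel(), data, ckpt_path=trainer.checkpoint_callback.best_model_path)
print(trainer2.ckpt_path)  # the resolved checkpoint path, now set within fit() as well
```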
@@ -837,12 +844,14 @@ def _validate_impl(
         # links data to the trainer
         self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)
 
-        self.validated_ckpt_path = self.__set_ckpt_path(
+        self._ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
 
+        self._validated_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
+
         # run validate
-        results = self._run(model, ckpt_path=self.validated_ckpt_path)
+        results = self._run(model, ckpt_path=self.ckpt_path)
 
         assert self.state.stopped
         self.validating = False
@@ -923,12 +932,14 @@ def _test_impl(
         # links data to the trainer
         self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
 
-        self.tested_ckpt_path = self.__set_ckpt_path(
+        self._ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
 
+        self._tested_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
+
         # run test
-        results = self._run(model, ckpt_path=self.tested_ckpt_path)
+        results = self._run(model, ckpt_path=self.ckpt_path)
 
         assert self.state.stopped
         self.testing = False
@@ -1009,11 +1020,13 @@ def _predict_impl(
         # links data to the trainer
         self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
 
-        self.predicted_ckpt_path = self.__set_ckpt_path(
+        self._ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
 
-        results = self._run(model, ckpt_path=self.predicted_ckpt_path)
+        self._predicted_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
+
+        results = self._run(model, ckpt_path=self.ckpt_path)
 
         assert self.state.stopped
         self.predicting = False
@@ -2217,6 +2230,74 @@ def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
 
         return resume_from_checkpoint
 
+    @property
+    def ckpt_path(self) -> Optional[str]:
+        """Set to the path/URL of checkpoints loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`,
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`,
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`, or
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
+        return self._ckpt_path
+
+    @property
+    def validated_ckpt_path(self) -> Optional[str]:
+        rank_zero_deprecation(
+            "The `Trainer.validated_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of checkpoints loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
+            " `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        return self._validated_ckpt_path
+
+    @validated_ckpt_path.setter
+    def validated_ckpt_path(self, ckpt_path: Optional[str]) -> None:
+        rank_zero_deprecation(
+            "The `Trainer.validated_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of checkpoints loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
+            " read-only `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        self._validated_ckpt_path = ckpt_path
+
+    @property
+    def tested_ckpt_path(self) -> Optional[str]:
+        rank_zero_deprecation(
+            "The `Trainer.tested_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of checkpoints loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
+            " `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        return self._tested_ckpt_path
+
+    @tested_ckpt_path.setter
+    def tested_ckpt_path(self, ckpt_path: Optional[str]) -> None:
+        rank_zero_deprecation(
+            "The `Trainer.tested_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of checkpoints loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
+            " read-only `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        self._tested_ckpt_path = ckpt_path
+
+    @property
+    def predicted_ckpt_path(self) -> Optional[str]:
+        rank_zero_deprecation(
+            "The `Trainer.predicted_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of checkpoints loaded via `Trainer.{fit,validate,test,predict}` should be accessed via"
+            " `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        return self._predicted_ckpt_path
+
+    @predicted_ckpt_path.setter
+    def predicted_ckpt_path(self, ckpt_path: Optional[str]) -> None:
+        rank_zero_deprecation(
+            "The `Trainer.predicted_ckpt_path` attribute was deprecated in v1.6 and will be removed in v1.8. The"
+            " path of checkpoints loaded via `Trainer.{fit,validate,test,predict}` should be accessed via the"
+            " read-only `Trainer.ckpt_path` instead.",
+            stacklevel=5,
+        )
+        self._predicted_ckpt_path = ckpt_path
+
     def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
         r"""
         Runs routine to create a checkpoint.
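Taken together, the new property and the deprecated accessors interact as sketched below (a standalone illustration based on this diff, not a test from the PR):

```python
import warnings

from pytorch_lightning import Trainer

trainer = Trainer()
assert trainer.ckpt_path is None  # nothing loaded yet

with warnings.catch_warnings(record=True) as record:
    warnings.simplefilter("always")
    _ = trainer.validated_ckpt_path    # getter emits a deprecation warning
    trainer.validated_ckpt_path = "x"  # setter warns too, and only writes the legacy field

assert len(record) == 2
assert trainer.ckpt_path is None  # the read-only property is untouched by the legacy setter
```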
50 changes: 26 additions & 24 deletions tests/callbacks/test_finetuning_callback.py
@@ -287,34 +287,36 @@ def configure_optimizers(self):
     trainer.fit(model)
 
 
-def test_complex_nested_model():
-    """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
-    directly themselves rather than exclusively their submodules containing parameters."""
-
-    class ConvBlock(nn.Module):
-        def __init__(self, in_channels, out_channels):
-            super().__init__()
-            self.conv = nn.Conv2d(in_channels, out_channels, 3)
-            self.act = nn.ReLU()
-            self.bn = nn.BatchNorm2d(out_channels)
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, 3)
+        self.act = nn.ReLU()
+        self.bn = nn.BatchNorm2d(out_channels)
 
-        def forward(self, x):
-            x = self.conv(x)
-            x = self.act(x)
-            return self.bn(x)
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.act(x)
+        return self.bn(x)
 
-    class ConvBlockParam(nn.Module):
-        def __init__(self, in_channels, out_channels):
-            super().__init__()
-            self.module_dict = nn.ModuleDict({"conv": nn.Conv2d(in_channels, out_channels, 3), "act": nn.ReLU()})
-            # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
-            self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
-            self.bn = nn.BatchNorm2d(out_channels)
+class ConvBlockParam(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.module_dict = nn.ModuleDict({"conv": nn.Conv2d(in_channels, out_channels, 3), "act": nn.ReLU()})
+        # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
+        self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
+        self.bn = nn.BatchNorm2d(out_channels)
 
-        def forward(self, x):
-            x = self.module_dict["conv"](x)
-            x = self.module_dict["act"](x)
-            return self.bn(x)
+    def forward(self, x):
+        x = self.module_dict["conv"](x)
+        x = self.module_dict["act"](x)
+        return self.bn(x)
+
+
+def test_complex_nested_model():
+    """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
+    directly themselves rather than exclusively their submodules containing parameters."""
 
     model = nn.Sequential(
         OrderedDict(
10 changes: 10 additions & 0 deletions tests/deprecated_api/test_remove_1-8.py
@@ -146,6 +146,16 @@ def test_v1_8_0_trainer_verbose_evaluate():
     trainer.verbose_evaluate = False
 
 
+@pytest.mark.parametrize("fn_prefix", ["validated", "tested", "predicted"], ids=["validated", "tested", "predicted"])
+def test_v1_8_0_trainer_ckpt_path_attributes(fn_prefix: str):
+    test_attr = f"{fn_prefix}_ckpt_path"
+    trainer = Trainer()
+    with pytest.deprecated_call(match=f"{test_attr}` attribute was deprecated in v1.6 and will be removed in v1.8"):
+        _ = getattr(trainer, test_attr)
+    with pytest.deprecated_call(match=f"{test_attr}` attribute was deprecated in v1.6 and will be removed in v1.8"):
+        setattr(trainer, test_attr, "v")
+
+
 def test_v1_8_0_deprecated_trainer_should_rank_save_checkpoint(tmpdir):
     trainer = Trainer()
     with pytest.deprecated_call(
22 changes: 10 additions & 12 deletions tests/trainer/test_trainer.py
@@ -686,8 +686,7 @@ def predict_step(self, batch, *_):
     trainer.fit(model)
 
     trainer_fn = getattr(trainer, fn)
-    path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
-    assert getattr(trainer, path_attr) is None
+    assert getattr(trainer, "ckpt_path") is None
 
     if ckpt_path == "best":
         # ckpt_path is 'best', meaning we load the best weights
@@ -698,20 +697,20 @@
             trainer_fn(model, ckpt_path=ckpt_path)
         else:
             trainer_fn(ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+        assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
 
         trainer_fn(model, ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+        assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
     elif ckpt_path is None:
         # ckpt_path is None, meaning we don't load any checkpoints and use the provided model
         trainer_fn(model, ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) is None
+        assert getattr(trainer, "ckpt_path") is None
 
         if save_top_k > 0:
             # ckpt_path is None with no model provided means load the best weights
             with pytest.warns(UserWarning, match="The best model of the previous `fit` call will be used"):
                 trainer_fn(ckpt_path=ckpt_path)
-            assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+            assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
     else:
         # specific checkpoint, pick one from saved ones
         if save_top_k == 0:
@@ -724,10 +723,10 @@
             ].absolute()
         )
         trainer_fn(ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) == ckpt_path
+        assert getattr(trainer, "ckpt_path") == ckpt_path
 
         trainer_fn(model, ckpt_path=ckpt_path)
-        assert getattr(trainer, path_attr) == ckpt_path
+        assert getattr(trainer, "ckpt_path") == ckpt_path
 
 
 @pytest.mark.parametrize("enable_checkpointing", (False, True))
@@ -758,15 +757,14 @@ def predict_step(self, batch, *_):
     trainer.fit(model)
 
     trainer_fn = getattr(trainer, fn)
-    path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
-    assert getattr(trainer, path_attr) is None
+    assert getattr(trainer, "ckpt_path") is None
 
     if enable_checkpointing:
         trainer_fn(ckpt_path="best")
-        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+        assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
 
         trainer_fn(model, ckpt_path="best")
-        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+        assert getattr(trainer, "ckpt_path") == trainer.checkpoint_callback.best_model_path
     else:
         with pytest.raises(MisconfigurationException, match="`ModelCheckpoint` is not configured."):
             trainer_fn(ckpt_path="best")