Skip to content

Fix inspection of unspecified args for container hparams #9125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Sep 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267))


- Fixed inspection of other args when a container is specified in `save_hyperparameters` ([#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125))


- Fixed `move_metrics_to_cpu` moving the loss on cpu while training on device ([#9308](https://github.com/PyTorchLightning/pytorch-lightning/pull/9308))


Expand Down
87 changes: 51 additions & 36 deletions pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@
from typing_extensions import Literal

import pytorch_lightning as pl
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.warnings import rank_zero_warn

if _OMEGACONF_AVAILABLE:
from omegaconf.dictconfig import DictConfig


def str_to_bool_or_str(val: str) -> Union[str, bool]:
"""Possibly convert a string representation of truth to bool.
Expand Down Expand Up @@ -204,46 +208,57 @@ def save_hyperparameters(
obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
) -> None:
"""See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`"""

hparams_container_types = [Namespace, dict]
if _OMEGACONF_AVAILABLE:
hparams_container_types.append(DictConfig)
# empty container
if len(args) == 1 and not isinstance(args, str) and not args[0]:
# args[0] is an empty container
return

if not frame:
current_frame = inspect.currentframe()
# inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available
if current_frame:
frame = current_frame.f_back
if not isinstance(frame, types.FrameType):
raise AttributeError("There is no `frame` available while being required.")

if is_dataclass(obj):
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
else:
init_args = get_init_args(frame)
assert init_args, "failed to inspect the obj init"

if ignore is not None:
if isinstance(ignore, str):
ignore = [ignore]
if isinstance(ignore, (list, tuple)):
ignore = [arg for arg in ignore if isinstance(arg, str)]
init_args = {k: v for k, v in init_args.items() if k not in ignore}

if not args:
# take all arguments
hp = init_args
obj._hparams_name = "kwargs" if hp else None
# container
elif len(args) == 1 and isinstance(args[0], tuple(hparams_container_types)):
hp = args[0]
obj._hparams_name = "hparams"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here it is immediately assume that the name of the arg was "hparams". previously, the logic below was checking for matching argument names in the signature.

While this old way of passing in hyperparameters is not our recommended way, we have not officially dropped backward compatibility for this and therefore this was an unwanted breaking change imo.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #9631 for an example of what was working previously.

obj._set_hparams(hp)
obj._hparams_initial = copy.deepcopy(obj._hparams)
return
# non-container args parsing
else:
# take only listed arguments in `save_hparams`
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
if len(isx_non_str) == 1:
hp = args[isx_non_str[0]]
cand_names = [k for k, v in init_args.items() if v == hp]
obj._hparams_name = cand_names[0] if cand_names else None
if not frame:
current_frame = inspect.currentframe()
# inspect.currentframe() return type is Optional[types.FrameType]
# current_frame.f_back called only if available
if current_frame:
frame = current_frame.f_back
if not isinstance(frame, types.FrameType):
raise AttributeError("There is no `frame` available while being required.")

if is_dataclass(obj):
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
else:
init_args = get_init_args(frame)
assert init_args, f"failed to inspect the obj init - {frame}"

if ignore is not None:
if isinstance(ignore, str):
ignore = [ignore]
if isinstance(ignore, (list, tuple, set)):
ignore = [arg for arg in ignore if isinstance(arg, str)]
init_args = {k: v for k, v in init_args.items() if k not in ignore}

if not args:
# take all arguments
hp = init_args
obj._hparams_name = "kwargs" if hp else None
else:
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
obj._hparams_name = "kwargs"
# take only listed arguments in `save_hparams`
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
if len(isx_non_str) == 1:
hp = args[isx_non_str[0]]
cand_names = [k for k, v in init_args.items() if v == hp]
obj._hparams_name = cand_names[0] if cand_names else None
else:
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
obj._hparams_name = "kwargs"

# `hparams` are expected here
if hp:
Expand Down
26 changes: 22 additions & 4 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from typing import Any, Dict
from unittest import mock
from unittest.mock import call, PropertyMock

import pytest
import torch
from omegaconf import OmegaConf

from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
Expand Down Expand Up @@ -528,16 +529,33 @@ def test_dm_init_from_datasets_dataloaders(iterable):
)


class DataModuleWithHparams(LightningDataModule):
# all args
class DataModuleWithHparams_0(LightningDataModule):
def __init__(self, arg0, arg1, kwarg0=None):
super().__init__()
self.save_hyperparameters()


def test_simple_hyperparameters_saving():
data = DataModuleWithHparams(10, "foo", kwarg0="bar")
# single arg
class DataModuleWithHparams_1(LightningDataModule):
def __init__(self, arg0, *args, **kwargs):
super().__init__()
self.save_hyperparameters(arg0)


def test_hyperparameters_saving():
data = DataModuleWithHparams_0(10, "foo", kwarg0="bar")
assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"})

data = DataModuleWithHparams_1(Namespace(**{"hello": "world"}), "foo", kwarg0="bar")
assert data.hparams == AttributeDict({"hello": "world"})

data = DataModuleWithHparams_1({"hello": "world"}, "foo", kwarg0="bar")
assert data.hparams == AttributeDict({"hello": "world"})

data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
assert data.hparams == OmegaConf.create({"hello": "world"})


def test_define_as_dataclass():
# makes sure that no functionality is broken and the user can still manually make
Expand Down