Skip to content

Commit 25bfd06

Browse files
awaelchlilexierule
authored andcommitted
revert #9125 / fix back-compatibility with saving hparams as a whole container (#9642)
1 parent 3415323 commit 25bfd06

File tree

3 files changed

+46
-53
lines changed

3 files changed

+46
-53
lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77

8-
## [1.4.8] - 2021-09-21
8+
## [1.4.8] - 2021-09-22
99

1010
- Fixed error reporting in DDP process reconciliation when processes are launched by an external agent (#9389)
1111
- Added PL_RECONCILE_PROCESS environment variable to enable process reconciliation regardless of cluster environment settings (#9389)
1212
- Fixed `add_argparse_args` raising `TypeError` when args are typed as `typing.Generic` in Python 3.6 ([#9554](https://github.com/PyTorchLightning/pytorch-lightning/pull/9554))
13-
13+
- Fixed back-compatibility for saving hyperparameters from a single container and inferring its argument name by reverting [#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125) ([#9642](https://github.com/PyTorchLightning/pytorch-lightning/pull/9642))
1414

1515
## [1.4.7] - 2021-09-14
1616

pytorch_lightning/utilities/parsing.py

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,8 @@
2222
from typing_extensions import Literal
2323

2424
import pytorch_lightning as pl
25-
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
2625
from pytorch_lightning.utilities.warnings import rank_zero_warn
2726

28-
if _OMEGACONF_AVAILABLE:
29-
from omegaconf.dictconfig import DictConfig
30-
3127

3228
def str_to_bool_or_str(val: str) -> Union[str, bool]:
3329
"""Possibly convert a string representation of truth to bool.
@@ -208,57 +204,46 @@ def save_hyperparameters(
208204
obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
209205
) -> None:
210206
"""See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`"""
211-
hparams_container_types = [Namespace, dict]
212-
if _OMEGACONF_AVAILABLE:
213-
hparams_container_types.append(DictConfig)
214-
# empty container
207+
215208
if len(args) == 1 and not isinstance(args, str) and not args[0]:
209+
# args[0] is an empty container
216210
return
217-
# container
218-
elif len(args) == 1 and isinstance(args[0], tuple(hparams_container_types)):
219-
hp = args[0]
220-
obj._hparams_name = "hparams"
221-
obj._set_hparams(hp)
222-
obj._hparams_initial = copy.deepcopy(obj._hparams)
223-
return
224-
# non-container args parsing
211+
212+
if not frame:
213+
current_frame = inspect.currentframe()
214+
# inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available
215+
if current_frame:
216+
frame = current_frame.f_back
217+
if not isinstance(frame, types.FrameType):
218+
raise AttributeError("There is no `frame` available while being required.")
219+
220+
if is_dataclass(obj):
221+
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
225222
else:
226-
if not frame:
227-
current_frame = inspect.currentframe()
228-
# inspect.currentframe() return type is Optional[types.FrameType]
229-
# current_frame.f_back called only if available
230-
if current_frame:
231-
frame = current_frame.f_back
232-
if not isinstance(frame, types.FrameType):
233-
raise AttributeError("There is no `frame` available while being required.")
234-
235-
if is_dataclass(obj):
236-
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
237-
else:
238-
init_args = get_init_args(frame)
239-
assert init_args, f"failed to inspect the obj init - {frame}"
240-
241-
if ignore is not None:
242-
if isinstance(ignore, str):
243-
ignore = [ignore]
244-
if isinstance(ignore, (list, tuple, set)):
245-
ignore = [arg for arg in ignore if isinstance(arg, str)]
246-
init_args = {k: v for k, v in init_args.items() if k not in ignore}
247-
248-
if not args:
249-
# take all arguments
250-
hp = init_args
251-
obj._hparams_name = "kwargs" if hp else None
223+
init_args = get_init_args(frame)
224+
assert init_args, "failed to inspect the obj init"
225+
226+
if ignore is not None:
227+
if isinstance(ignore, str):
228+
ignore = [ignore]
229+
if isinstance(ignore, (list, tuple)):
230+
ignore = [arg for arg in ignore if isinstance(arg, str)]
231+
init_args = {k: v for k, v in init_args.items() if k not in ignore}
232+
233+
if not args:
234+
# take all arguments
235+
hp = init_args
236+
obj._hparams_name = "kwargs" if hp else None
237+
else:
238+
# take only listed arguments in `save_hparams`
239+
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
240+
if len(isx_non_str) == 1:
241+
hp = args[isx_non_str[0]]
242+
cand_names = [k for k, v in init_args.items() if v == hp]
243+
obj._hparams_name = cand_names[0] if cand_names else None
252244
else:
253-
# take only listed arguments in `save_hparams`
254-
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
255-
if len(isx_non_str) == 1:
256-
hp = args[isx_non_str[0]]
257-
cand_names = [k for k, v in init_args.items() if v == hp]
258-
obj._hparams_name = cand_names[0] if cand_names else None
259-
else:
260-
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
261-
obj._hparams_name = "kwargs"
245+
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
246+
obj._hparams_name = "kwargs"
262247

263248
# `hparams` are expected here
264249
if hp:

tests/models/test_hparams.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,14 @@ def test_empty_hparams_container(tmpdir):
689689
assert not model.hparams
690690

691691

692+
def test_hparams_name_from_container(tmpdir):
693+
"""Test that save_hyperparameters(container) captures the name of the argument correctly."""
694+
model = HparamsKwargsContainerModel(a=1, b=2)
695+
assert model._hparams_name is None
696+
model = HparamsNamespaceContainerModel(Namespace(a=1, b=2))
697+
assert model._hparams_name == "config"
698+
699+
692700
@dataclass
693701
class DataClassModel(BoringModel):
694702

0 commit comments

Comments
 (0)