Skip to content

Commit dda6f5f

Browse files
committed
Revert "Fix inspection of unspecified args for container hparams (#9125)"
This reverts commit 904dde7.
1 parent a71be50 commit dda6f5f

File tree

1 file changed

+36
-51
lines changed

1 file changed

+36
-51
lines changed

pytorch_lightning/utilities/parsing.py

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,8 @@
2222
from typing_extensions import Literal
2323

2424
import pytorch_lightning as pl
25-
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
2625
from pytorch_lightning.utilities.warnings import rank_zero_warn
2726

28-
if _OMEGACONF_AVAILABLE:
29-
from omegaconf.dictconfig import DictConfig
30-
3127

3228
def str_to_bool_or_str(val: str) -> Union[str, bool]:
3329
"""Possibly convert a string representation of truth to bool. Returns the input otherwise. Based on the python
@@ -205,57 +201,46 @@ def save_hyperparameters(
205201
obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
206202
) -> None:
207203
"""See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`"""
208-
hparams_container_types = [Namespace, dict]
209-
if _OMEGACONF_AVAILABLE:
210-
hparams_container_types.append(DictConfig)
211-
# empty container
204+
212205
if len(args) == 1 and not isinstance(args, str) and not args[0]:
206+
# args[0] is an empty container
213207
return
214-
# container
215-
elif len(args) == 1 and isinstance(args[0], tuple(hparams_container_types)):
216-
hp = args[0]
217-
obj._hparams_name = "hparams"
218-
obj._set_hparams(hp)
219-
obj._hparams_initial = copy.deepcopy(obj._hparams)
220-
return
221-
# non-container args parsing
208+
209+
if not frame:
210+
current_frame = inspect.currentframe()
211+
# inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available
212+
if current_frame:
213+
frame = current_frame.f_back
214+
if not isinstance(frame, types.FrameType):
215+
raise AttributeError("There is no `frame` available while being required.")
216+
217+
if is_dataclass(obj):
218+
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
222219
else:
223-
if not frame:
224-
current_frame = inspect.currentframe()
225-
# inspect.currentframe() return type is Optional[types.FrameType]
226-
# current_frame.f_back called only if available
227-
if current_frame:
228-
frame = current_frame.f_back
229-
if not isinstance(frame, types.FrameType):
230-
raise AttributeError("There is no `frame` available while being required.")
231-
232-
if is_dataclass(obj):
233-
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
234-
else:
235-
init_args = get_init_args(frame)
236-
assert init_args, f"failed to inspect the obj init - {frame}"
237-
238-
if ignore is not None:
239-
if isinstance(ignore, str):
240-
ignore = [ignore]
241-
if isinstance(ignore, (list, tuple, set)):
242-
ignore = [arg for arg in ignore if isinstance(arg, str)]
243-
init_args = {k: v for k, v in init_args.items() if k not in ignore}
244-
245-
if not args:
246-
# take all arguments
247-
hp = init_args
248-
obj._hparams_name = "kwargs" if hp else None
220+
init_args = get_init_args(frame)
221+
assert init_args, "failed to inspect the obj init"
222+
223+
if ignore is not None:
224+
if isinstance(ignore, str):
225+
ignore = [ignore]
226+
if isinstance(ignore, (list, tuple)):
227+
ignore = [arg for arg in ignore if isinstance(arg, str)]
228+
init_args = {k: v for k, v in init_args.items() if k not in ignore}
229+
230+
if not args:
231+
# take all arguments
232+
hp = init_args
233+
obj._hparams_name = "kwargs" if hp else None
234+
else:
235+
# take only listed arguments in `save_hparams`
236+
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
237+
if len(isx_non_str) == 1:
238+
hp = args[isx_non_str[0]]
239+
cand_names = [k for k, v in init_args.items() if v == hp]
240+
obj._hparams_name = cand_names[0] if cand_names else None
249241
else:
250-
# take only listed arguments in `save_hparams`
251-
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
252-
if len(isx_non_str) == 1:
253-
hp = args[isx_non_str[0]]
254-
cand_names = [k for k, v in init_args.items() if v == hp]
255-
obj._hparams_name = cand_names[0] if cand_names else None
256-
else:
257-
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
258-
obj._hparams_name = "kwargs"
242+
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
243+
obj._hparams_name = "kwargs"
259244

260245
# `hparams` are expected here
261246
if hp:

0 commit comments

Comments
 (0)