|
22 | 22 | from typing_extensions import Literal
|
23 | 23 |
|
24 | 24 | import pytorch_lightning as pl
|
25 |
| -from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE |
26 | 25 | from pytorch_lightning.utilities.warnings import rank_zero_warn
|
27 | 26 |
|
28 |
| -if _OMEGACONF_AVAILABLE: |
29 |
| - from omegaconf.dictconfig import DictConfig |
30 |
| - |
31 | 27 |
|
32 | 28 | def str_to_bool_or_str(val: str) -> Union[str, bool]:
|
33 | 29 | """Possibly convert a string representation of truth to bool.
|
@@ -208,57 +204,46 @@ def save_hyperparameters(
|
208 | 204 | obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
|
209 | 205 | ) -> None:
|
210 | 206 | """See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`"""
|
211 |
| - hparams_container_types = [Namespace, dict] |
212 |
| - if _OMEGACONF_AVAILABLE: |
213 |
| - hparams_container_types.append(DictConfig) |
214 |
| - # empty container |
| 207 | + |
215 | 208 | if len(args) == 1 and not isinstance(args, str) and not args[0]:
|
| 209 | + # args[0] is an empty container |
216 | 210 | return
|
217 |
| - # container |
218 |
| - elif len(args) == 1 and isinstance(args[0], tuple(hparams_container_types)): |
219 |
| - hp = args[0] |
220 |
| - obj._hparams_name = "hparams" |
221 |
| - obj._set_hparams(hp) |
222 |
| - obj._hparams_initial = copy.deepcopy(obj._hparams) |
223 |
| - return |
224 |
| - # non-container args parsing |
| 211 | + |
| 212 | + if not frame: |
| 213 | + current_frame = inspect.currentframe() |
| 214 | + # inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available |
| 215 | + if current_frame: |
| 216 | + frame = current_frame.f_back |
| 217 | + if not isinstance(frame, types.FrameType): |
| 218 | + raise AttributeError("There is no `frame` available while being required.") |
| 219 | + |
| 220 | + if is_dataclass(obj): |
| 221 | + init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} |
225 | 222 | else:
|
226 |
| - if not frame: |
227 |
| - current_frame = inspect.currentframe() |
228 |
| - # inspect.currentframe() return type is Optional[types.FrameType] |
229 |
| - # current_frame.f_back called only if available |
230 |
| - if current_frame: |
231 |
| - frame = current_frame.f_back |
232 |
| - if not isinstance(frame, types.FrameType): |
233 |
| - raise AttributeError("There is no `frame` available while being required.") |
234 |
| - |
235 |
| - if is_dataclass(obj): |
236 |
| - init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} |
237 |
| - else: |
238 |
| - init_args = get_init_args(frame) |
239 |
| - assert init_args, f"failed to inspect the obj init - {frame}" |
240 |
| - |
241 |
| - if ignore is not None: |
242 |
| - if isinstance(ignore, str): |
243 |
| - ignore = [ignore] |
244 |
| - if isinstance(ignore, (list, tuple, set)): |
245 |
| - ignore = [arg for arg in ignore if isinstance(arg, str)] |
246 |
| - init_args = {k: v for k, v in init_args.items() if k not in ignore} |
247 |
| - |
248 |
| - if not args: |
249 |
| - # take all arguments |
250 |
| - hp = init_args |
251 |
| - obj._hparams_name = "kwargs" if hp else None |
| 223 | + init_args = get_init_args(frame) |
| 224 | + assert init_args, "failed to inspect the obj init" |
| 225 | + |
| 226 | + if ignore is not None: |
| 227 | + if isinstance(ignore, str): |
| 228 | + ignore = [ignore] |
| 229 | + if isinstance(ignore, (list, tuple)): |
| 230 | + ignore = [arg for arg in ignore if isinstance(arg, str)] |
| 231 | + init_args = {k: v for k, v in init_args.items() if k not in ignore} |
| 232 | + |
| 233 | + if not args: |
| 234 | + # take all arguments |
| 235 | + hp = init_args |
| 236 | + obj._hparams_name = "kwargs" if hp else None |
| 237 | + else: |
| 238 | + # take only listed arguments in `save_hparams` |
| 239 | + isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] |
| 240 | + if len(isx_non_str) == 1: |
| 241 | + hp = args[isx_non_str[0]] |
| 242 | + cand_names = [k for k, v in init_args.items() if v == hp] |
| 243 | + obj._hparams_name = cand_names[0] if cand_names else None |
252 | 244 | else:
|
253 |
| - # take only listed arguments in `save_hparams` |
254 |
| - isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] |
255 |
| - if len(isx_non_str) == 1: |
256 |
| - hp = args[isx_non_str[0]] |
257 |
| - cand_names = [k for k, v in init_args.items() if v == hp] |
258 |
| - obj._hparams_name = cand_names[0] if cand_names else None |
259 |
| - else: |
260 |
| - hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} |
261 |
| - obj._hparams_name = "kwargs" |
| 245 | + hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} |
| 246 | + obj._hparams_name = "kwargs" |
262 | 247 |
|
263 | 248 | # `hparams` are expected here
|
264 | 249 | if hp:
|
|
0 commit comments