|
22 | 22 | from typing_extensions import Literal
|
23 | 23 |
|
24 | 24 | import pytorch_lightning as pl
|
25 |
| -from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE |
26 | 25 | from pytorch_lightning.utilities.warnings import rank_zero_warn
|
27 | 26 |
|
28 |
| -if _OMEGACONF_AVAILABLE: |
29 |
| - from omegaconf.dictconfig import DictConfig |
30 |
| - |
31 | 27 |
|
32 | 28 | def str_to_bool_or_str(val: str) -> Union[str, bool]:
|
33 | 29 | """Possibly convert a string representation of truth to bool. Returns the input otherwise. Based on the python
|
@@ -205,57 +201,46 @@ def save_hyperparameters(
|
205 | 201 | obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
|
206 | 202 | ) -> None:
|
207 | 203 | """See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`"""
|
208 |
| - hparams_container_types = [Namespace, dict] |
209 |
| - if _OMEGACONF_AVAILABLE: |
210 |
| - hparams_container_types.append(DictConfig) |
211 |
| - # empty container |
| 204 | + |
212 | 205 | if len(args) == 1 and not isinstance(args, str) and not args[0]:
|
| 206 | + # args[0] is an empty container |
213 | 207 | return
|
214 |
| - # container |
215 |
| - elif len(args) == 1 and isinstance(args[0], tuple(hparams_container_types)): |
216 |
| - hp = args[0] |
217 |
| - obj._hparams_name = "hparams" |
218 |
| - obj._set_hparams(hp) |
219 |
| - obj._hparams_initial = copy.deepcopy(obj._hparams) |
220 |
| - return |
221 |
| - # non-container args parsing |
| 208 | + |
| 209 | + if not frame: |
| 210 | + current_frame = inspect.currentframe() |
| 211 | + # inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available |
| 212 | + if current_frame: |
| 213 | + frame = current_frame.f_back |
| 214 | + if not isinstance(frame, types.FrameType): |
| 215 | + raise AttributeError("There is no `frame` available while being required.") |
| 216 | + |
| 217 | + if is_dataclass(obj): |
| 218 | + init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} |
222 | 219 | else:
|
223 |
| - if not frame: |
224 |
| - current_frame = inspect.currentframe() |
225 |
| - # inspect.currentframe() return type is Optional[types.FrameType] |
226 |
| - # current_frame.f_back called only if available |
227 |
| - if current_frame: |
228 |
| - frame = current_frame.f_back |
229 |
| - if not isinstance(frame, types.FrameType): |
230 |
| - raise AttributeError("There is no `frame` available while being required.") |
231 |
| - |
232 |
| - if is_dataclass(obj): |
233 |
| - init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} |
234 |
| - else: |
235 |
| - init_args = get_init_args(frame) |
236 |
| - assert init_args, f"failed to inspect the obj init - {frame}" |
237 |
| - |
238 |
| - if ignore is not None: |
239 |
| - if isinstance(ignore, str): |
240 |
| - ignore = [ignore] |
241 |
| - if isinstance(ignore, (list, tuple, set)): |
242 |
| - ignore = [arg for arg in ignore if isinstance(arg, str)] |
243 |
| - init_args = {k: v for k, v in init_args.items() if k not in ignore} |
244 |
| - |
245 |
| - if not args: |
246 |
| - # take all arguments |
247 |
| - hp = init_args |
248 |
| - obj._hparams_name = "kwargs" if hp else None |
| 220 | + init_args = get_init_args(frame) |
| 221 | + assert init_args, "failed to inspect the obj init" |
| 222 | + |
| 223 | + if ignore is not None: |
| 224 | + if isinstance(ignore, str): |
| 225 | + ignore = [ignore] |
| 226 | + if isinstance(ignore, (list, tuple)): |
| 227 | + ignore = [arg for arg in ignore if isinstance(arg, str)] |
| 228 | + init_args = {k: v for k, v in init_args.items() if k not in ignore} |
| 229 | + |
| 230 | + if not args: |
| 231 | + # take all arguments |
| 232 | + hp = init_args |
| 233 | + obj._hparams_name = "kwargs" if hp else None |
| 234 | + else: |
| 235 | + # take only listed arguments in `save_hparams` |
| 236 | + isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] |
| 237 | + if len(isx_non_str) == 1: |
| 238 | + hp = args[isx_non_str[0]] |
| 239 | + cand_names = [k for k, v in init_args.items() if v == hp] |
| 240 | + obj._hparams_name = cand_names[0] if cand_names else None |
249 | 241 | else:
|
250 |
| - # take only listed arguments in `save_hparams` |
251 |
| - isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] |
252 |
| - if len(isx_non_str) == 1: |
253 |
| - hp = args[isx_non_str[0]] |
254 |
| - cand_names = [k for k, v in init_args.items() if v == hp] |
255 |
| - obj._hparams_name = cand_names[0] if cand_names else None |
256 |
| - else: |
257 |
| - hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} |
258 |
| - obj._hparams_name = "kwargs" |
| 242 | + hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} |
| 243 | + obj._hparams_name = "kwargs" |
259 | 244 |
|
260 | 245 | # `hparams` are expected here
|
261 | 246 | if hp:
|
|
0 commit comments