Skip to content

Fix mypy in utilities.parsing #8132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 50 additions & 28 deletions pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
import types
from argparse import Namespace
from dataclasses import fields, is_dataclass
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union

from typing_extensions import Literal

import pytorch_lightning as pl
from pytorch_lightning.utilities.warnings import rank_zero_warn


Expand Down Expand Up @@ -54,10 +57,10 @@ def str_to_bool(val: str) -> bool:
>>> str_to_bool('FALSE')
False
"""
val = str_to_bool_or_str(val)
if isinstance(val, bool):
return val
raise ValueError(f'invalid truth value {val}')
val_converted = str_to_bool_or_str(val)
if isinstance(val_converted, bool):
return val_converted
raise ValueError(f'invalid truth value {val_converted}')


def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
Expand All @@ -72,13 +75,13 @@ def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
>>> str_to_bool_or_int("abc")
'abc'
"""
val = str_to_bool_or_str(val)
if isinstance(val, bool):
return val
val_converted = str_to_bool_or_str(val)
if isinstance(val_converted, bool):
return val_converted
try:
return int(val)
return int(val_converted)
except ValueError:
return val
return val_converted


def is_picklable(obj: object) -> bool:
Expand All @@ -91,7 +94,7 @@ def is_picklable(obj: object) -> bool:
return False


def clean_namespace(hparams):
def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
"""Removes all unpicklable entries from hparams"""

hparams_dict = hparams
Expand All @@ -105,7 +108,7 @@ def clean_namespace(hparams):
del hparams_dict[k]


def parse_class_init_keys(cls) -> Tuple[str, str, str]:
def parse_class_init_keys(cls: Type['pl.LightningModule']) -> Tuple[str, Optional[str], Optional[str]]:
"""Parse key words for standard self, *args and **kwargs

>>> class Model():
Expand All @@ -121,18 +124,22 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]:
# self is always first
n_self = init_params[0].name

def _get_first_if_any(params, param_type):
def _get_first_if_any(
params: List[inspect.Parameter],
param_type: Literal[inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD],
) -> Optional[str]:
for p in params:
if p.kind == param_type:
return p.name
return None

n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL)
n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD)

return n_self, n_args, n_kwargs


def get_init_args(frame) -> dict:
def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
_, _, _, local_vars = inspect.getargvalues(frame)
if '__class__' not in local_vars:
return {}
Expand All @@ -143,12 +150,18 @@ def get_init_args(frame) -> dict:
exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args')
# only collect variables that appear in the signature
local_args = {k: local_vars[k] for k in init_parameters.keys()}
local_args.update(local_args.get(kwargs_var, {}))
# kwargs_var might be None => raised an error by mypy
if kwargs_var:
local_args.update(local_args.get(kwargs_var, {}))
local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames}
return local_args


def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
def collect_init_args(
frame: types.FrameType,
path_args: List[Dict[str, Any]],
inside: bool = False,
) -> List[Dict[str, Any]]:
"""
Recursively collects the arguments passed to the child constructors in the inheritance tree.

Expand All @@ -163,6 +176,10 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
most specific class in the hierarchy.
"""
_, _, _, local_vars = inspect.getargvalues(frame)
# frame.f_back must be of a type types.FrameType for get_init_args/collect_init_args due to mypy
if not isinstance(frame.f_back, types.FrameType):
return path_args

if '__class__' in local_vars:
local_args = get_init_args(frame)
# recursive update
Expand All @@ -174,7 +191,7 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
return path_args


def flatten_dict(source, result=None):
def flatten_dict(source: Dict[str, Any], result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
if result is None:
result = {}

Expand All @@ -189,7 +206,7 @@ def flatten_dict(source, result=None):

def save_hyperparameters(
obj: Any,
*args,
*args: Any,
ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None
) -> None:
Expand All @@ -200,7 +217,12 @@ def save_hyperparameters(
return

if not frame:
frame = inspect.currentframe().f_back
current_frame = inspect.currentframe()
# inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available
if current_frame:
frame = current_frame.f_back
if not isinstance(frame, types.FrameType):
raise AttributeError("There is no `frame` available while being required.")

if is_dataclass(obj):
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
Expand Down Expand Up @@ -253,16 +275,16 @@ class AttributeDict(Dict):
"my-key": 3.14
"""

def __getattr__(self, key):
def __getattr__(self, key: str) -> Optional[Any]:
try:
return self[key]
except KeyError as exp:
raise AttributeError(f'Missing attribute "{key}"') from exp

def __setattr__(self, key, val):
def __setattr__(self, key: str, val: Any) -> None:
self[key] = val

def __repr__(self):
def __repr__(self) -> str:
if not len(self):
return ""
max_key_length = max([len(str(k)) for k in self])
Expand All @@ -272,14 +294,14 @@ def __repr__(self):
return out


def _lightning_get_all_attr_holders(model, attribute):
def _lightning_get_all_attr_holders(model: 'pl.LightningModule', attribute: str) -> List[Any]:
"""
Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule.
"""
trainer = getattr(model, 'trainer', None)

holders = []
holders: List[Any] = []

# Check if attribute in model
if hasattr(model, attribute):
Expand All @@ -297,7 +319,7 @@ def _lightning_get_all_attr_holders(model, attribute):
return holders


def _lightning_get_first_attr_holder(model, attribute):
def _lightning_get_first_attr_holder(model: 'pl.LightningModule', attribute: str) -> Optional[Any]:
"""
Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule,
Expand All @@ -310,15 +332,15 @@ def _lightning_get_first_attr_holder(model, attribute):
return holders[-1]


def lightning_hasattr(model, attribute):
def lightning_hasattr(model: 'pl.LightningModule', attribute: str) -> bool:
"""
Special hasattr for Lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule.
"""
return _lightning_get_first_attr_holder(model, attribute) is not None


def lightning_getattr(model, attribute):
def lightning_getattr(model: 'pl.LightningModule', attribute: str) -> Optional[Any]:
"""
Special getattr for Lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule.
Expand All @@ -340,7 +362,7 @@ def lightning_getattr(model, attribute):
return getattr(holder, attribute)


def lightning_setattr(model, attribute, value):
def lightning_setattr(model: 'pl.LightningModule', attribute: str, value: Any) -> None:
"""
Special setattr for Lightning. Checks for attribute in model namespace
and the old hparams namespace/dict.
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ ignore_errors = True
ignore_errors = True
[mypy-pytorch_lightning.utilities.cli]
ignore_errors = False
[mypy-pytorch_lightning.utilities.parsing]
ignore_errors = False

# todo: add proper typing to this module...
[mypy-pl_examples.*]
Expand Down