Fix mypy errors attributed to pytorch_lightning.core.datamodule #13693


Merged

merged 43 commits into master from codeq/datamodule on Aug 26, 2022

Changes from all commits (43 commits):
43c2ad2  remove module from pyproject.toml for ci code-checks (jxtngx, Jul 17, 2022)
9686043  update (jxtngx, Jul 17, 2022)
bd6fb0f  update return type (jxtngx, Jul 17, 2022)
5db82b4  update for LightningDataModule codeq (jxtngx, Jul 17, 2022)
cba2910  Merge branch 'Lightning-AI:master' into codeq/datamodule (jxtngx, Jul 23, 2022)
de0f36e  Merge branch 'master' into codeq/datamodule (jxtngx, Jul 23, 2022)
16a279f  Merge branch 'master' into codeq/datamodule (jxtngx, Jul 25, 2022)
90aa6c9  Merge branch 'master' into codeq/datamodule (jxtngx, Jul 26, 2022)
cbae81b  Merge branch 'master' into codeq/datamodule (jxtngx, Jul 28, 2022)
089981a  update (jxtngx, Jul 28, 2022)
6087294  Merge branch 'master' into codeq/datamodule (jxtngx, Jul 28, 2022)
049d1c1  update (jxtngx, Jul 29, 2022)
2f6e162  update (jxtngx, Jul 29, 2022)
ec292c3  Merge branch 'master' into codeq/datamodule (jxtngx, Aug 1, 2022)
b189f22  Merge branch 'master' into codeq/datamodule (jxtngx, Aug 1, 2022)
2d8adcc  Merge branch 'master' into codeq/datamodule (jxtngx, Aug 1, 2022)
517576d  update (jxtngx, Aug 2, 2022)
ae32bc1  Merge branch 'master' into codeq/datamodule (jxtngx, Aug 2, 2022)
020872d  commit suggestion (jxtngx, Aug 3, 2022)
850263e  update (jxtngx, Aug 3, 2022)
c9edb9d  Merge branch 'master' into codeq/datamodule (jxtngx, Aug 3, 2022)
6a45961  Merge branch 'master' into codeq/datamodule (jxtngx, Aug 5, 2022)
f3775d8  Merge branch 'master' into codeq/datamodule (jxtngx, Aug 5, 2022)
eb8ecb9  resolve merge conflicts (jxtngx, Aug 8, 2022)
3ae9e36  update (jxtngx, Aug 8, 2022)
6e2062e  Merge branch 'master' into codeq/datamodule (jxtngx, Aug 9, 2022)
92197a0  Merge branch 'master' into codeq/datamodule (jxtngx, Aug 9, 2022)
2585d0a  self review (rohitgr7, Aug 9, 2022)
ab2cf24  fix (rohitgr7, Aug 9, 2022)
8af819b  Merge branch 'master' into codeq/datamodule (jxtngx, Aug 10, 2022)
cc55ab9  update (jxtngx, Aug 10, 2022)
a08f574  Merge branch 'master' into codeq/datamodule (jxtngx, Aug 12, 2022)
7d56668  Merge branch 'master' into codeq/datamodule (carmocca, Aug 22, 2022)
28fae30  fix introduced mypy error (Aug 22, 2022)
e0d25ec  fix docstring (Aug 22, 2022)
402d025  Merge branch 'master' into codeq/datamodule (otaj, Aug 23, 2022)
e7cef8a  Merge branch 'master' into codeq/datamodule (otaj, Aug 24, 2022)
d53275b  Merge branch 'master' into codeq/datamodule (otaj, Aug 25, 2022)
5e5629c  Merge branch 'master' into codeq/datamodule (rohitgr7, Aug 25, 2022)
4053b7e  Merge branch 'master' into codeq/datamodule (otaj, Aug 26, 2022)
3ee9dde  merge master (Aug 26, 2022)
16528aa  merge master (Aug 26, 2022)
cc11363  Merge branch 'master' into codeq/datamodule (Borda, Aug 26, 2022)
1 change: 0 additions & 1 deletion pyproject.toml
@@ -50,7 +50,6 @@ warn_no_return = "False"
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
module = [
    "pytorch_lightning.callbacks.progress.rich_progress",
-    "pytorch_lightning.core.datamodule",
    "pytorch_lightning.profilers.base",
    "pytorch_lightning.profilers.pytorch",
    "pytorch_lightning.strategies.sharded",
67 changes: 48 additions & 19 deletions src/pytorch_lightning/core/datamodule.py
@@ -18,11 +18,17 @@

from torch.utils.data import DataLoader, Dataset, IterableDataset

+import pytorch_lightning as pl
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.core.saving import _load_from_checkpoint
-from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.argparse import (
+    add_argparse_args,
+    from_argparse_args,
+    get_init_arguments_and_types,
+    parse_argparser,
+)
+from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN, _PATH, EVAL_DATALOADERS, TRAIN_DATALOADERS


class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
@@ -54,7 +60,7 @@ def teardown(self):
# called on every process in DDP
"""

-    name: str = ...
+    name: Optional[str] = None
CHECKPOINT_HYPER_PARAMS_KEY = "datamodule_hyper_parameters"
CHECKPOINT_HYPER_PARAMS_NAME = "datamodule_hparams_name"
CHECKPOINT_HYPER_PARAMS_TYPE = "datamodule_hparams_type"
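
Context for this hunk: under mypy, assigning `Ellipsis` to a `str`-annotated attribute is an `[assignment]` error, which is why the sentinel became `Optional[str] = None`. A minimal repro sketch (class names are illustrative; the error text is paraphrased):

```python
from typing import Optional


class Before:
    name: str = ...  # mypy error: expression has type "ellipsis" [assignment]


class After:
    name: Optional[str] = None  # accepted by mypy
```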
@@ -65,7 +71,7 @@ def __init__(self) -> None:
self.trainer = None

@classmethod
-    def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
+    def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs: Any) -> _ADD_ARGPARSE_RETURN:
"""Extends existing argparse by default `LightningDataModule` attributes.

Example::
@@ -76,7 +82,9 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
return add_argparse_args(cls, parent_parser, **kwargs)

@classmethod
-    def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
+    def from_argparse_args(
+        cls, args: Union[Namespace, ArgumentParser], **kwargs: Any
+    ) -> Union["pl.LightningDataModule", "pl.Trainer"]:
"""Create an instance from CLI arguments.

Args:
@@ -91,6 +99,10 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
"""
return from_argparse_args(cls, args, **kwargs)

+    @classmethod
+    def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
+        return parse_argparser(cls, arg_parser)

@classmethod
def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
r"""Scans the DataModule signature and returns argument names, types and default values.
@@ -101,6 +113,15 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
"""
return get_init_arguments_and_types(cls)

+    @classmethod
+    def get_deprecated_arg_names(cls) -> List:
[Review thread on get_deprecated_arg_names]

Contributor: Why is this added?

jxtngx (Contributor, Author), Aug 3, 2022: @carmocca mypy raises a union-attr error, since utilities.argparse.add_argparse_args expects this method; however, get_deprecated_arg_names isn't a function in utilities.argparse, it is a class method of Trainer. I think the more correct thing to do is to move the implementation of get_deprecated_arg_names to utilities.argparse and then update core.datamodule and trainer.trainer accordingly... do you agree?

Contributor: I agree. But let's do that in a PR before this one or a PR after.

"""Returns a list with deprecated DataModule arguments."""
depr_arg_names: List[str] = []
for name, val in cls.__dict__.items():
if name.startswith("DEPRECATED") and isinstance(val, (tuple, list)):
depr_arg_names.extend(val)
return depr_arg_names
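
As a side note, a minimal sketch of what the new hook picks up: any class attribute whose name starts with "DEPRECATED" and holds a tuple or list contributes its entries. The subclass and attribute below are hypothetical, not part of the PR:

```python
import pytorch_lightning as pl


class MyDataModule(pl.LightningDataModule):  # hypothetical subclass
    # collected by get_deprecated_arg_names(): the attribute name starts
    # with "DEPRECATED" and its value is a tuple of argument names
    DEPRECATED_ARGS = ("old_batch_size",)


print(MyDataModule.get_deprecated_arg_names())  # -> ['old_batch_size']
```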

@classmethod
def from_datasets(
cls,
@@ -111,7 +132,7 @@ def from_datasets(
batch_size: int = 1,
num_workers: int = 0,
**datamodule_kwargs: Any,
-    ):
+    ) -> "LightningDataModule":
r"""
Create an instance from torch.utils.data.Dataset.

@@ -132,24 +153,32 @@ def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
shuffle &= not isinstance(ds, IterableDataset)
return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)

-        def train_dataloader():
+        def train_dataloader() -> TRAIN_DATALOADERS:
+            assert train_dataset
+
            if isinstance(train_dataset, Mapping):
                return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
            if isinstance(train_dataset, Sequence):
                return [dataloader(ds, shuffle=True) for ds in train_dataset]
            return dataloader(train_dataset, shuffle=True)

-        def val_dataloader():
+        def val_dataloader() -> EVAL_DATALOADERS:
+            assert val_dataset
+
            if isinstance(val_dataset, Sequence):
                return [dataloader(ds) for ds in val_dataset]
            return dataloader(val_dataset)

-        def test_dataloader():
+        def test_dataloader() -> EVAL_DATALOADERS:
+            assert test_dataset
+
            if isinstance(test_dataset, Sequence):
                return [dataloader(ds) for ds in test_dataset]
            return dataloader(test_dataset)

-        def predict_dataloader():
+        def predict_dataloader() -> EVAL_DATALOADERS:
+            assert predict_dataset
+
            if isinstance(predict_dataset, Sequence):
                return [dataloader(ds) for ds in predict_dataset]
            return dataloader(predict_dataset)
@@ -160,19 +189,19 @@ def predict_dataloader():
if accepts_kwargs:
special_kwargs = candidate_kwargs
else:
-            accepted_params = set(accepted_params)
-            accepted_params.discard("self")
-            special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_params}
+            accepted_param_names = set(accepted_params)
+            accepted_param_names.discard("self")
+            special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_param_names}

        datamodule = cls(**datamodule_kwargs, **special_kwargs)
        if train_dataset is not None:
-            datamodule.train_dataloader = train_dataloader
+            datamodule.train_dataloader = train_dataloader  # type: ignore[assignment]
        if val_dataset is not None:
-            datamodule.val_dataloader = val_dataloader
+            datamodule.val_dataloader = val_dataloader  # type: ignore[assignment]
        if test_dataset is not None:
-            datamodule.test_dataloader = test_dataloader
+            datamodule.test_dataloader = test_dataloader  # type: ignore[assignment]
        if predict_dataset is not None:
-            datamodule.predict_dataloader = predict_dataloader
+            datamodule.predict_dataloader = predict_dataloader  # type: ignore[assignment]
return datamodule
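
For reference, a minimal usage sketch of `from_datasets` with the new return annotation (toy tensors stand in for real data):

```python
import torch
from torch.utils.data import TensorDataset

from pytorch_lightning import LightningDataModule

train_ds = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
val_ds = TensorDataset(torch.randn(16, 3), torch.randint(0, 2, (16,)))

# mypy now infers `dm: LightningDataModule` from the annotated classmethod
dm = LightningDataModule.from_datasets(
    train_dataset=train_ds,
    val_dataset=val_ds,
    batch_size=8,
    num_workers=0,
)
loader = dm.train_dataloader()  # a single DataLoader (TRAIN_DATALOADERS)
```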

def state_dict(self) -> Dict[str, Any]:
@@ -196,8 +225,8 @@ def load_from_checkpoint(
cls,
checkpoint_path: Union[_PATH, IO],
hparams_file: Optional[_PATH] = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
r"""
Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint
it stores the arguments passed to ``__init__`` in the checkpoint under ``"datamodule_hyper_parameters"``.
13 changes: 6 additions & 7 deletions src/pytorch_lightning/core/saving.py
@@ -171,10 +171,10 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:

def _load_from_checkpoint(
cls: Union[Type["ModelIO"], Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
-    checkpoint_path: Union[str, IO],
+    checkpoint_path: Union[_PATH, IO],
    map_location: _MAP_LOCATION_TYPE = None,
-    hparams_file: Optional[str] = None,
-    strict: bool = True,
+    hparams_file: Optional[_PATH] = None,
+    strict: Optional[bool] = None,
**kwargs: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
if map_location is None:
@@ -183,7 +183,7 @@ def _load_from_checkpoint(
checkpoint = pl_load(checkpoint_path, map_location=map_location)

if hparams_file is not None:
-        extension = hparams_file.split(".")[-1]
+        extension = str(hparams_file).split(".")[-1]
if extension.lower() == "csv":
hparams = load_hparams_from_tags_csv(hparams_file)
elif extension.lower() in ("yml", "yaml"):
@@ -201,16 +201,14 @@

if issubclass(cls, pl.LightningDataModule):
return _load_state(cls, checkpoint, **kwargs)
-    # allow cls to be evaluated as subclassed LightningModule or,
-    # as LightningModule for internal tests
if issubclass(cls, pl.LightningModule):
return _load_state(cls, checkpoint, strict=strict, **kwargs)


def _load_state(
cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
checkpoint: Dict[str, Any],
-    strict: bool = True,
+    strict: Optional[bool] = None,
**cls_kwargs_new: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
cls_spec = inspect.getfullargspec(cls.__init__)
@@ -257,6 +255,7 @@ def _load_state(
return obj

# load the state_dict on the model automatically
+    assert strict is not None
keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)

if not strict:
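
The `strict` change above follows a common mypy pattern: widen the parameter to `Optional[bool]` so the DataModule path (which has no `strict` semantics) can leave it as `None`, then narrow it back to `bool` with an `assert` just before the model-only use. A standalone sketch of the pattern, with illustrative names not taken from the PR:

```python
from typing import Any, Dict, Optional


def load_sketch(checkpoint: Dict[str, Any], strict: Optional[bool] = None) -> None:
    if "state_dict" not in checkpoint:
        # datamodule-style checkpoint: `strict` is never consulted
        return
    # model path: the assert narrows Optional[bool] to bool for mypy
    assert strict is not None
    print("loading state_dict with strict =", strict)
```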
39 changes: 18 additions & 21 deletions src/pytorch_lightning/utilities/argparse.py
@@ -15,7 +15,6 @@

import inspect
import os
-from abc import ABC
from argparse import _ArgumentGroup, ArgumentParser, Namespace
from ast import literal_eval
from contextlib import suppress
@@ -24,22 +23,17 @@

import pytorch_lightning as pl
from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str
+from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN

_T = TypeVar("_T", bound=Callable[..., Any])


-class ParseArgparserDataType(ABC):
-    def __init__(self, *_: Any, **__: Any) -> None:
-        pass
-
-    @classmethod
-    def parse_argparser(cls, args: "ArgumentParser") -> Any:
-        pass
+_ARGPARSE_CLS = Union[Type["pl.LightningDataModule"], Type["pl.Trainer"]]


def from_argparse_args(
-    cls: Type[ParseArgparserDataType], args: Union[Namespace, ArgumentParser], **kwargs: Any
-) -> ParseArgparserDataType:
+    cls: _ARGPARSE_CLS,
+    args: Union[Namespace, ArgumentParser],
+    **kwargs: Any,
+) -> Union["pl.LightningDataModule", "pl.Trainer"]:
"""Create an instance from CLI arguments. Eventually use variables from OS environment which are defined as
``"PL_<CLASS-NAME>_<CLASS_ARUMENT_NAME>"``.

@@ -72,7 +66,7 @@ def from_argparse_args(
return cls(**trainer_kwargs)


-def parse_argparser(cls: Type["pl.Trainer"], arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
+def parse_argparser(cls: _ARGPARSE_CLS, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
"""Parse CLI arguments, required for custom bool types."""
args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser

@@ -97,7 +91,7 @@ def parse_argparser(cls: Type["pl.Trainer"], arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
return Namespace(**modified_args)


-def parse_env_variables(cls: Type["pl.Trainer"], template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
+def parse_env_variables(cls: _ARGPARSE_CLS, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
"""Parse environment arguments if they are defined.

Examples:
@@ -127,7 +121,7 @@ def parse_env_variables(cls: Type["pl.Trainer"], template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
return Namespace(**env_args)


-def get_init_arguments_and_types(cls: Any) -> List[Tuple[str, Tuple, Any]]:
+def get_init_arguments_and_types(cls: _ARGPARSE_CLS) -> List[Tuple[str, Tuple, Any]]:
r"""Scans the class signature and returns argument names, types and default values.

Returns:
@@ -155,7 +149,7 @@ def get_init_arguments_and_types(cls: Any) -> List[Tuple[str, Tuple, Any]]:
return name_type_default


-def _get_abbrev_qualified_cls_name(cls: Any) -> str:
+def _get_abbrev_qualified_cls_name(cls: _ARGPARSE_CLS) -> str:
assert isinstance(cls, type), repr(cls)
if cls.__module__.startswith("pytorch_lightning."):
# Abbreviate.
@@ -165,8 +159,11 @@ def _get_abbrev_qualified_cls_name(cls: Any) -> str:


def add_argparse_args(
cls: Type["pl.Trainer"], parent_parser: ArgumentParser, *, use_argument_group: bool = True
) -> Union[_ArgumentGroup, ArgumentParser]:
cls: _ARGPARSE_CLS,
parent_parser: ArgumentParser,
*,
use_argument_group: bool = True,
) -> _ADD_ARGPARSE_RETURN:
r"""Extends existing argparse by default attributes for ``cls``.

Args:
@@ -207,10 +204,10 @@ def add_argparse_args(
>>> args = parser.parse_args([])
"""
if isinstance(parent_parser, _ArgumentGroup):
raise RuntimeError("Please only pass an ArgumentParser instance.")
raise RuntimeError("Please only pass an `ArgumentParser` instance.")
if use_argument_group:
group_name = _get_abbrev_qualified_cls_name(cls)
-        parser: Union[_ArgumentGroup, ArgumentParser] = parent_parser.add_argument_group(group_name)
+        parser: _ADD_ARGPARSE_RETURN = parent_parser.add_argument_group(group_name)
else:
parser = ArgumentParser(parents=[parent_parser], add_help=False)

@@ -222,7 +219,7 @@

# Get symbols from cls or init function.
for symbol in (cls, cls.__init__):
-        args_and_types = get_init_arguments_and_types(symbol)
+        args_and_types = get_init_arguments_and_types(symbol)  # type: ignore[arg-type]
args_and_types = [x for x in args_and_types if x[0] not in ignore_arg_names]
if len(args_and_types) > 0:
break
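
For reference, a minimal sketch of the argparse round trip these helpers support; the Trainer flag shown is standard, but treat the exact arguments as illustrative:

```python
from argparse import ArgumentParser

from pytorch_lightning import Trainer

parser = ArgumentParser()
# extends the parser with Trainer.__init__ arguments; with the default
# use_argument_group=True an _ArgumentGroup is returned, which is why the
# return type is the _ADD_ARGPARSE_RETURN union rather than ArgumentParser
Trainer.add_argparse_args(parser)

args = parser.parse_args(["--max_epochs", "3"])
trainer = Trainer.from_argparse_args(args)  # cls typed via _ARGPARSE_CLS
```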
2 changes: 2 additions & 0 deletions src/pytorch_lightning/utilities/types.py
@@ -16,6 +16,7 @@
- Do not include any `_TYPE` suffix
- Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`)
"""
+from argparse import _ArgumentGroup, ArgumentParser
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
@@ -50,6 +51,7 @@
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
_DEVICE = Union[torch.device, str, int]
_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]]
+_ADD_ARGPARSE_RETURN = Union[_ArgumentGroup, ArgumentParser]


@runtime_checkable
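
As a reminder of what the dataloader aliases referenced in the datamodule diff admit, a small sketch with toy data (shapes are arbitrary; the mapping form mirrors what `from_datasets` builds for Mapping inputs):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.zeros(4, 2))

# all valid TRAIN_DATALOADERS values:
single = DataLoader(ds)                  # a lone DataLoader
many = [DataLoader(ds), DataLoader(ds)]  # Sequence[DataLoader]
named = {"a": DataLoader(ds)}            # mapping of named loaders

# EVAL_DATALOADERS is narrower: Union[DataLoader, Sequence[DataLoader]]
```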
2 changes: 1 addition & 1 deletion tests/tests_pytorch/utilities/test_argparse.py
@@ -207,7 +207,7 @@ def test_add_argparse_args(cls, name):


def test_negative_add_argparse_args():
-    with pytest.raises(RuntimeError, match="Please only pass an ArgumentParser instance."):
+    with pytest.raises(RuntimeError, match="Please only pass an `ArgumentParser` instance."):
parser = ArgumentParser()
add_argparse_args(AddArgparseArgsExampleClass, parser.add_argument_group("bad workflow"))
