
Commit ed84d04

jxtngx authored
Fix mypy errors attributed to pytorch_lightning.core.datamodule (#13693)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: rohitgr7 <[email protected]>
Co-authored-by: otaj <[email protected]>
1 parent fafd254 commit ed84d04

File tree

6 files changed: +74 −49 lines


pyproject.toml (−1)

@@ -50,7 +50,6 @@ warn_no_return = "False"
 # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
-    "pytorch_lightning.core.datamodule",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.strategies.sharded",

src/pytorch_lightning/core/datamodule.py (+47 −19)

@@ -22,8 +22,13 @@
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.core.mixins import HyperparametersMixin
 from pytorch_lightning.core.saving import _load_from_checkpoint
-from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.argparse import (
+    add_argparse_args,
+    from_argparse_args,
+    get_init_arguments_and_types,
+    parse_argparser,
+)
+from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN, _PATH, EVAL_DATALOADERS, TRAIN_DATALOADERS
 
 
 class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
@@ -55,7 +60,7 @@ def teardown(self):
             # called on every process in DDP
     """
 
-    name: str = ...
+    name: Optional[str] = None
     CHECKPOINT_HYPER_PARAMS_KEY = "datamodule_hyper_parameters"
     CHECKPOINT_HYPER_PARAMS_NAME = "datamodule_hparams_name"
     CHECKPOINT_HYPER_PARAMS_TYPE = "datamodule_hparams_type"
@@ -66,7 +71,7 @@ def __init__(self) -> None:
         self.trainer: Optional["pl.Trainer"] = None
 
     @classmethod
-    def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
+    def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs: Any) -> _ADD_ARGPARSE_RETURN:
         """Extends existing argparse by default `LightningDataModule` attributes.
 
         Example::
@@ -77,7 +82,9 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentP
         return add_argparse_args(cls, parent_parser, **kwargs)
 
     @classmethod
-    def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
+    def from_argparse_args(
+        cls, args: Union[Namespace, ArgumentParser], **kwargs: Any
+    ) -> Union["pl.LightningDataModule", "pl.Trainer"]:
         """Create an instance from CLI arguments.
 
         Args:
@@ -92,6 +99,10 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
         """
         return from_argparse_args(cls, args, **kwargs)
 
+    @classmethod
+    def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
+        return parse_argparser(cls, arg_parser)
+
     @classmethod
     def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
         r"""Scans the DataModule signature and returns argument names, types and default values.
@@ -102,6 +113,15 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
         """
         return get_init_arguments_and_types(cls)
 
+    @classmethod
+    def get_deprecated_arg_names(cls) -> List:
+        """Returns a list with deprecated DataModule arguments."""
+        depr_arg_names: List[str] = []
+        for name, val in cls.__dict__.items():
+            if name.startswith("DEPRECATED") and isinstance(val, (tuple, list)):
+                depr_arg_names.extend(val)
+        return depr_arg_names
+
     @classmethod
     def from_datasets(
         cls,
@@ -112,7 +132,7 @@ def from_datasets(
         batch_size: int = 1,
         num_workers: int = 0,
         **datamodule_kwargs: Any,
-    ):
+    ) -> "LightningDataModule":
         r"""
         Create an instance from torch.utils.data.Dataset.
 
@@ -133,24 +153,32 @@ def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
             shuffle &= not isinstance(ds, IterableDataset)
             return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
 
-        def train_dataloader():
+        def train_dataloader() -> TRAIN_DATALOADERS:
+            assert train_dataset
+
             if isinstance(train_dataset, Mapping):
                 return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
             if isinstance(train_dataset, Sequence):
                 return [dataloader(ds, shuffle=True) for ds in train_dataset]
             return dataloader(train_dataset, shuffle=True)
 
-        def val_dataloader():
+        def val_dataloader() -> EVAL_DATALOADERS:
+            assert val_dataset
+
             if isinstance(val_dataset, Sequence):
                 return [dataloader(ds) for ds in val_dataset]
             return dataloader(val_dataset)
 
-        def test_dataloader():
+        def test_dataloader() -> EVAL_DATALOADERS:
+            assert test_dataset
+
             if isinstance(test_dataset, Sequence):
                 return [dataloader(ds) for ds in test_dataset]
             return dataloader(test_dataset)
 
-        def predict_dataloader():
+        def predict_dataloader() -> EVAL_DATALOADERS:
+            assert predict_dataset
+
             if isinstance(predict_dataset, Sequence):
                 return [dataloader(ds) for ds in predict_dataset]
             return dataloader(predict_dataset)
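Note: the closures above now carry the public alias types. A hedged usage sketch of the API this hunk annotates (dataset shapes made up), showing the single-loader case each alias covers:

    import torch
    import pytorch_lightning as pl
    from torch.utils.data import TensorDataset

    train = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
    val = TensorDataset(torch.randn(16, 3), torch.randint(0, 2, (16,)))

    dm = pl.LightningDataModule.from_datasets(train_dataset=train, val_dataset=val, batch_size=16)
    train_loader = dm.train_dataloader()  # a single DataLoader: the simplest TRAIN_DATALOADERS case
    val_loader = dm.val_dataloader()      # a single DataLoader: the simplest EVAL_DATALOADERS case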
@@ -161,19 +189,19 @@ def predict_dataloader():
         if accepts_kwargs:
             special_kwargs = candidate_kwargs
         else:
-            accepted_params = set(accepted_params)
-            accepted_params.discard("self")
-            special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_params}
+            accepted_param_names = set(accepted_params)
+            accepted_param_names.discard("self")
+            special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_param_names}
 
         datamodule = cls(**datamodule_kwargs, **special_kwargs)
         if train_dataset is not None:
-            datamodule.train_dataloader = train_dataloader
+            datamodule.train_dataloader = train_dataloader  # type: ignore[assignment]
         if val_dataset is not None:
-            datamodule.val_dataloader = val_dataloader
+            datamodule.val_dataloader = val_dataloader  # type: ignore[assignment]
         if test_dataset is not None:
-            datamodule.test_dataloader = test_dataloader
+            datamodule.test_dataloader = test_dataloader  # type: ignore[assignment]
         if predict_dataset is not None:
-            datamodule.predict_dataloader = predict_dataloader
+            datamodule.predict_dataloader = predict_dataloader  # type: ignore[assignment]
         return datamodule
 
     def state_dict(self) -> Dict[str, Any]:
@@ -197,8 +225,8 @@ def load_from_checkpoint(
         cls,
         checkpoint_path: Union[_PATH, IO],
         hparams_file: Optional[_PATH] = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
         r"""
         Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint
         it stores the arguments passed to ``__init__`` in the checkpoint under ``"datamodule_hyper_parameters"``.
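Note: the `# type: ignore[assignment]` comments are needed because assigning a plain function over a declared method does not type-check. A minimal standalone sketch of the pattern (names hypothetical):

    class Widget:
        def describe(self) -> str:
            return "default"

    def make_widget(label: str) -> Widget:
        def describe() -> str:  # a closure without `self`, like the dataloader closures above
            return label

        w = Widget()
        w.describe = describe  # type: ignore[assignment]  # mypy flags assigning to a method
        return w

    print(make_widget("custom").describe())  # prints "custom"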

src/pytorch_lightning/core/saving.py (+6 −7)

@@ -171,10 +171,10 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
 
 def _load_from_checkpoint(
     cls: Union[Type["ModelIO"], Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
-    checkpoint_path: Union[str, IO],
+    checkpoint_path: Union[_PATH, IO],
     map_location: _MAP_LOCATION_TYPE = None,
-    hparams_file: Optional[str] = None,
-    strict: bool = True,
+    hparams_file: Optional[_PATH] = None,
+    strict: Optional[bool] = None,
     **kwargs: Any,
 ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
     if map_location is None:
@@ -183,7 +183,7 @@ def _load_from_checkpoint(
     checkpoint = pl_load(checkpoint_path, map_location=map_location)
 
     if hparams_file is not None:
-        extension = hparams_file.split(".")[-1]
+        extension = str(hparams_file).split(".")[-1]
         if extension.lower() == "csv":
             hparams = load_hparams_from_tags_csv(hparams_file)
         elif extension.lower() in ("yml", "yaml"):
@@ -201,16 +201,14 @@ def _load_from_checkpoint(
 
     if issubclass(cls, pl.LightningDataModule):
         return _load_state(cls, checkpoint, **kwargs)
-    # allow cls to be evaluated as subclassed LightningModule or,
-    # as LightningModule for internal tests
     if issubclass(cls, pl.LightningModule):
         return _load_state(cls, checkpoint, strict=strict, **kwargs)
 
 
 def _load_state(
     cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
     checkpoint: Dict[str, Any],
-    strict: bool = True,
+    strict: Optional[bool] = None,
     **cls_kwargs_new: Any,
 ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
     cls_spec = inspect.getfullargspec(cls.__init__)
@@ -257,6 +255,7 @@ def _load_state(
         return obj
 
     # load the state_dict on the model automatically
+    assert strict is not None
     keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
 
     if not strict:
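Note: widening `hparams_file` to `_PATH` is what forces the `str(hparams_file)` coercion, since `pathlib.Path` has no `.split()`. A small sketch, assuming `_PATH = Union[str, Path]` as in pytorch_lightning.utilities.types:

    from pathlib import Path
    from typing import Union

    _PATH = Union[str, Path]  # assumption: mirrors pytorch_lightning.utilities.types

    def extension_of(hparams_file: _PATH) -> str:
        # str() accepts both union members; Path lacks .split()
        return str(hparams_file).split(".")[-1]

    assert extension_of("hparams.yaml") == "yaml"
    assert extension_of(Path("logs/hparams.csv")) == "csv"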

src/pytorch_lightning/utilities/argparse.py (+18 −21)

@@ -15,7 +15,6 @@
 
 import inspect
 import os
-from abc import ABC
 from argparse import _ArgumentGroup, ArgumentParser, Namespace
 from ast import literal_eval
 from contextlib import suppress
@@ -24,22 +23,17 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str
+from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN
 
 _T = TypeVar("_T", bound=Callable[..., Any])
-
-
-class ParseArgparserDataType(ABC):
-    def __init__(self, *_: Any, **__: Any) -> None:
-        pass
-
-    @classmethod
-    def parse_argparser(cls, args: "ArgumentParser") -> Any:
-        pass
+_ARGPARSE_CLS = Union[Type["pl.LightningDataModule"], Type["pl.Trainer"]]
 
 
 def from_argparse_args(
-    cls: Type[ParseArgparserDataType], args: Union[Namespace, ArgumentParser], **kwargs: Any
-) -> ParseArgparserDataType:
+    cls: _ARGPARSE_CLS,
+    args: Union[Namespace, ArgumentParser],
+    **kwargs: Any,
+) -> Union["pl.LightningDataModule", "pl.Trainer"]:
     """Create an instance from CLI arguments. Eventually use variables from OS environment which are defined as
     ``"PL_<CLASS-NAME>_<CLASS_ARUMENT_NAME>"``.
 
@@ -72,7 +66,7 @@ def from_argparse_args(
     return cls(**trainer_kwargs)
 
 
-def parse_argparser(cls: Type["pl.Trainer"], arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
+def parse_argparser(cls: _ARGPARSE_CLS, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
     """Parse CLI arguments, required for custom bool types."""
     args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser
 
@@ -97,7 +91,7 @@ def parse_argparser(cls: Type["pl.Trainer"], arg_parser: Union[ArgumentParser, N
     return Namespace(**modified_args)
 
 
-def parse_env_variables(cls: Type["pl.Trainer"], template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
+def parse_env_variables(cls: _ARGPARSE_CLS, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
     """Parse environment arguments if they are defined.
 
     Examples:
@@ -127,7 +121,7 @@ def parse_env_variables(cls: Type["pl.Trainer"], template: str = "PL_%(cls_name)
     return Namespace(**env_args)
 
 
-def get_init_arguments_and_types(cls: Any) -> List[Tuple[str, Tuple, Any]]:
+def get_init_arguments_and_types(cls: _ARGPARSE_CLS) -> List[Tuple[str, Tuple, Any]]:
     r"""Scans the class signature and returns argument names, types and default values.
 
     Returns:
@@ -155,7 +149,7 @@ def get_init_arguments_and_types(cls: Any) -> List[Tuple[str, Tuple, Any]]:
     return name_type_default
 
 
-def _get_abbrev_qualified_cls_name(cls: Any) -> str:
+def _get_abbrev_qualified_cls_name(cls: _ARGPARSE_CLS) -> str:
     assert isinstance(cls, type), repr(cls)
     if cls.__module__.startswith("pytorch_lightning."):
         # Abbreviate.
@@ -165,8 +159,11 @@ def _get_abbrev_qualified_cls_name(cls: Any) -> str:
 
 
 def add_argparse_args(
-    cls: Type["pl.Trainer"], parent_parser: ArgumentParser, *, use_argument_group: bool = True
-) -> Union[_ArgumentGroup, ArgumentParser]:
+    cls: _ARGPARSE_CLS,
+    parent_parser: ArgumentParser,
+    *,
+    use_argument_group: bool = True,
+) -> _ADD_ARGPARSE_RETURN:
     r"""Extends existing argparse by default attributes for ``cls``.
 
     Args:
@@ -207,10 +204,10 @@ def add_argparse_args(
     >>> args = parser.parse_args([])
     """
     if isinstance(parent_parser, _ArgumentGroup):
-        raise RuntimeError("Please only pass an ArgumentParser instance.")
+        raise RuntimeError("Please only pass an `ArgumentParser` instance.")
     if use_argument_group:
         group_name = _get_abbrev_qualified_cls_name(cls)
-        parser: Union[_ArgumentGroup, ArgumentParser] = parent_parser.add_argument_group(group_name)
+        parser: _ADD_ARGPARSE_RETURN = parent_parser.add_argument_group(group_name)
     else:
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
 
@@ -222,7 +219,7 @@ def add_argparse_args(
 
     # Get symbols from cls or init function.
     for symbol in (cls, cls.__init__):
-        args_and_types = get_init_arguments_and_types(symbol)
+        args_and_types = get_init_arguments_and_types(symbol)  # type: ignore[arg-type]
         args_and_types = [x for x in args_and_types if x[0] not in ignore_arg_names]
         if len(args_and_types) > 0:
             break
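Note: replacing the `ParseArgparserDataType` placeholder ABC with a `Union` of concrete class objects gives mypy a real constructor to check at each call site. A self-contained sketch of the pattern (classes hypothetical, standing in for the datamodule and trainer):

    from typing import Any, Type, Union

    class DataModuleLike:
        def __init__(self, **kwargs: Any) -> None: ...

    class TrainerLike:
        def __init__(self, **kwargs: Any) -> None: ...

    _CLS = Union[Type[DataModuleLike], Type[TrainerLike]]  # cf. _ARGPARSE_CLS

    def build(cls: _CLS, **kwargs: Any) -> Union[DataModuleLike, TrainerLike]:
        return cls(**kwargs)  # mypy checks the call against both constructors

    print(type(build(TrainerLike)).__name__)  # TrainerLike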

src/pytorch_lightning/utilities/types.py (+2)

@@ -16,6 +16,7 @@
 - Do not include any `_TYPE` suffix
 - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`)
 """
+from argparse import _ArgumentGroup, ArgumentParser
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
@@ -50,6 +51,7 @@
 EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
 _DEVICE = Union[torch.device, str, int]
 _MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]]
+_ADD_ARGPARSE_RETURN = Union[_ArgumentGroup, ArgumentParser]
 
 
 @runtime_checkable
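Note: both `ArgumentParser` and the private `_ArgumentGroup` expose `add_argument()`, which is why a single alias can describe either branch of `add_argparse_args`. A hedged sketch of using such an alias (function name invented):

    from argparse import _ArgumentGroup, ArgumentParser
    from typing import Union

    _ADD_ARGPARSE_RETURN = Union[_ArgumentGroup, ArgumentParser]

    def extend(parent: ArgumentParser, use_group: bool) -> _ADD_ARGPARSE_RETURN:
        parser: _ADD_ARGPARSE_RETURN = parent.add_argument_group("demo") if use_group else parent
        parser.add_argument("--batch_size", type=int, default=1)  # valid on either type
        return parser

    top = ArgumentParser()
    extend(top, use_group=True)
    print(top.parse_args(["--batch_size", "8"]).batch_size)  # 8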

tests/tests_pytorch/utilities/test_argparse.py (+1 −1)

@@ -207,7 +207,7 @@ def test_add_argparse_args(cls, name):
 
 
 def test_negative_add_argparse_args():
-    with pytest.raises(RuntimeError, match="Please only pass an ArgumentParser instance."):
+    with pytest.raises(RuntimeError, match="Please only pass an `ArgumentParser` instance."):
         parser = ArgumentParser()
         add_argparse_args(AddArgparseArgsExampleClass, parser.add_argument_group("bad workflow"))
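Note: `pytest.raises(match=...)` applies the pattern with `re.search`, so the test string must track the new message verbatim (the backticks and `.` are harmless regex-wise). A minimal standalone sketch:

    import pytest

    def bad_workflow() -> None:
        raise RuntimeError("Please only pass an `ArgumentParser` instance.")

    def test_bad_workflow() -> None:
        with pytest.raises(RuntimeError, match="Please only pass an `ArgumentParser` instance."):
            bad_workflow()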
