diff --git a/pyproject.toml b/pyproject.toml index 18ae22feeb963..45f8ae678ec0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ warn_no_return = "False" # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",' module = [ "pytorch_lightning.callbacks.progress.rich_progress", - "pytorch_lightning.core.datamodule", "pytorch_lightning.profilers.base", "pytorch_lightning.profilers.pytorch", "pytorch_lightning.strategies.sharded", diff --git a/src/pytorch_lightning/core/datamodule.py b/src/pytorch_lightning/core/datamodule.py index 4edde3fe6a3ae..1bd19bc6e2657 100644 --- a/src/pytorch_lightning/core/datamodule.py +++ b/src/pytorch_lightning/core/datamodule.py @@ -18,11 +18,17 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset +import pytorch_lightning as pl from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks from pytorch_lightning.core.mixins import HyperparametersMixin from pytorch_lightning.core.saving import _load_from_checkpoint -from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types -from pytorch_lightning.utilities.types import _PATH +from pytorch_lightning.utilities.argparse import ( + add_argparse_args, + from_argparse_args, + get_init_arguments_and_types, + parse_argparser, +) +from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN, _PATH, EVAL_DATALOADERS, TRAIN_DATALOADERS class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin): @@ -54,7 +60,7 @@ def teardown(self): # called on every process in DDP """ - name: str = ... 
+ name: Optional[str] = None CHECKPOINT_HYPER_PARAMS_KEY = "datamodule_hyper_parameters" CHECKPOINT_HYPER_PARAMS_NAME = "datamodule_hparams_name" CHECKPOINT_HYPER_PARAMS_TYPE = "datamodule_hparams_type" @@ -65,7 +71,7 @@ def __init__(self) -> None: self.trainer = None @classmethod - def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: + def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs: Any) -> _ADD_ARGPARSE_RETURN: """Extends existing argparse by default `LightningDataModule` attributes. Example:: @@ -76,7 +82,9 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentP return add_argparse_args(cls, parent_parser, **kwargs) @classmethod - def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): + def from_argparse_args( + cls, args: Union[Namespace, ArgumentParser], **kwargs: Any + ) -> Union["pl.LightningDataModule", "pl.Trainer"]: """Create an instance from CLI arguments. Args: @@ -91,6 +99,10 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): """ return from_argparse_args(cls, args, **kwargs) + @classmethod + def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: + return parse_argparser(cls, arg_parser) + @classmethod def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: r"""Scans the DataModule signature and returns argument names, types and default values. 
@@ -101,6 +113,15 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: """ return get_init_arguments_and_types(cls) + @classmethod + def get_deprecated_arg_names(cls) -> List[str]: + """Returns a list with deprecated DataModule arguments.""" + depr_arg_names: List[str] = [] + for name, val in cls.__dict__.items(): + if name.startswith("DEPRECATED") and isinstance(val, (tuple, list)): + depr_arg_names.extend(val) + return depr_arg_names + @classmethod def from_datasets( cls, @@ -111,7 +132,7 @@ def from_datasets( batch_size: int = 1, num_workers: int = 0, **datamodule_kwargs: Any, - ): + ) -> "LightningDataModule": r""" Create an instance from torch.utils.data.Dataset. @@ -132,24 +153,32 @@ def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader: shuffle &= not isinstance(ds, IterableDataset) return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True) - def train_dataloader(): + def train_dataloader() -> TRAIN_DATALOADERS: + assert train_dataset is not None + if isinstance(train_dataset, Mapping): return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()} if isinstance(train_dataset, Sequence): return [dataloader(ds, shuffle=True) for ds in train_dataset] return dataloader(train_dataset, shuffle=True) - def val_dataloader(): + def val_dataloader() -> EVAL_DATALOADERS: + assert val_dataset is not None + if isinstance(val_dataset, Sequence): return [dataloader(ds) for ds in val_dataset] return dataloader(val_dataset) - def test_dataloader(): + def test_dataloader() -> EVAL_DATALOADERS: + assert test_dataset is not None + if isinstance(test_dataset, Sequence): return [dataloader(ds) for ds in test_dataset] return dataloader(test_dataset) - def predict_dataloader(): + def predict_dataloader() -> EVAL_DATALOADERS: + assert predict_dataset is not None + if isinstance(predict_dataset, Sequence): return [dataloader(ds) for ds in predict_dataset] return dataloader(predict_dataset) @@ -160,19 +189,19 @@ def predict_dataloader(): if
accepts_kwargs: special_kwargs = candidate_kwargs else: - accepted_params = set(accepted_params) - accepted_params.discard("self") - special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_params} + accepted_param_names = set(accepted_params) + accepted_param_names.discard("self") + special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_param_names} datamodule = cls(**datamodule_kwargs, **special_kwargs) if train_dataset is not None: - datamodule.train_dataloader = train_dataloader + datamodule.train_dataloader = train_dataloader # type: ignore[assignment] if val_dataset is not None: - datamodule.val_dataloader = val_dataloader + datamodule.val_dataloader = val_dataloader # type: ignore[assignment] if test_dataset is not None: - datamodule.test_dataloader = test_dataloader + datamodule.test_dataloader = test_dataloader # type: ignore[assignment] if predict_dataset is not None: - datamodule.predict_dataloader = predict_dataloader + datamodule.predict_dataloader = predict_dataloader # type: ignore[assignment] return datamodule def state_dict(self) -> Dict[str, Any]: @@ -196,8 +225,8 @@ def load_from_checkpoint( cls, checkpoint_path: Union[_PATH, IO], hparams_file: Optional[_PATH] = None, - **kwargs, - ): + **kwargs: Any, + ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: r""" Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to ``__init__`` in the checkpoint under ``"datamodule_hyper_parameters"``. 
diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py index ffdc0988a1a6e..2df7a661c2bd6 100644 --- a/src/pytorch_lightning/core/saving.py +++ b/src/pytorch_lightning/core/saving.py @@ -171,10 +171,10 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None: def _load_from_checkpoint( cls: Union[Type["ModelIO"], Type["pl.LightningModule"], Type["pl.LightningDataModule"]], - checkpoint_path: Union[str, IO], + checkpoint_path: Union[_PATH, IO], map_location: _MAP_LOCATION_TYPE = None, - hparams_file: Optional[str] = None, - strict: bool = True, + hparams_file: Optional[_PATH] = None, + strict: Optional[bool] = None, **kwargs: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: if map_location is None: @@ -183,7 +183,7 @@ def _load_from_checkpoint( checkpoint = pl_load(checkpoint_path, map_location=map_location) if hparams_file is not None: - extension = hparams_file.split(".")[-1] + extension = str(hparams_file).split(".")[-1] if extension.lower() == "csv": hparams = load_hparams_from_tags_csv(hparams_file) elif extension.lower() in ("yml", "yaml"): @@ -201,8 +201,6 @@ def _load_from_checkpoint( if issubclass(cls, pl.LightningDataModule): return _load_state(cls, checkpoint, **kwargs) - # allow cls to be evaluated as subclassed LightningModule or, - # as LightningModule for internal tests if issubclass(cls, pl.LightningModule): return _load_state(cls, checkpoint, strict=strict, **kwargs) @@ -210,7 +208,7 @@ def _load_from_checkpoint( def _load_state( cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], checkpoint: Dict[str, Any], - strict: bool = True, + strict: Optional[bool] = None, **cls_kwargs_new: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: cls_spec = inspect.getfullargspec(cls.__init__) @@ -257,6 +255,7 @@ def _load_state( return obj # load the state_dict on the model automatically + strict = True if strict is None else strict keys = obj.load_state_dict(checkpoint["state_dict"],
strict=strict) if not strict: diff --git a/src/pytorch_lightning/utilities/argparse.py b/src/pytorch_lightning/utilities/argparse.py index 58ced375fcae5..26277db183410 100644 --- a/src/pytorch_lightning/utilities/argparse.py +++ b/src/pytorch_lightning/utilities/argparse.py @@ -15,7 +15,6 @@ import inspect import os -from abc import ABC from argparse import _ArgumentGroup, ArgumentParser, Namespace from ast import literal_eval from contextlib import suppress @@ -24,22 +23,17 @@ import pytorch_lightning as pl from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str +from pytorch_lightning.utilities.types import _ADD_ARGPARSE_RETURN _T = TypeVar("_T", bound=Callable[..., Any]) - - -class ParseArgparserDataType(ABC): - def __init__(self, *_: Any, **__: Any) -> None: - pass - - @classmethod - def parse_argparser(cls, args: "ArgumentParser") -> Any: - pass +_ARGPARSE_CLS = Union[Type["pl.LightningDataModule"], Type["pl.Trainer"]] def from_argparse_args( - cls: Type[ParseArgparserDataType], args: Union[Namespace, ArgumentParser], **kwargs: Any -) -> ParseArgparserDataType: + cls: _ARGPARSE_CLS, + args: Union[Namespace, ArgumentParser], + **kwargs: Any, +) -> Union["pl.LightningDataModule", "pl.Trainer"]: """Create an instance from CLI arguments. Eventually use variables from OS environment which are defined as ``"PL__"``. 
@@ -72,7 +66,7 @@ def from_argparse_args( return cls(**trainer_kwargs) -def parse_argparser(cls: Type["pl.Trainer"], arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: +def parse_argparser(cls: _ARGPARSE_CLS, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: """Parse CLI arguments, required for custom bool types.""" args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser @@ -97,7 +91,7 @@ def parse_argparser(cls: Type["pl.Trainer"], arg_parser: Union[ArgumentParser, N return Namespace(**modified_args) -def parse_env_variables(cls: Type["pl.Trainer"], template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: +def parse_env_variables(cls: _ARGPARSE_CLS, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: """Parse environment arguments if they are defined. Examples: @@ -127,7 +121,7 @@ def parse_env_variables(cls: Type["pl.Trainer"], template: str = "PL_%(cls_name) return Namespace(**env_args) -def get_init_arguments_and_types(cls: Any) -> List[Tuple[str, Tuple, Any]]: +def get_init_arguments_and_types(cls: _ARGPARSE_CLS) -> List[Tuple[str, Tuple, Any]]: r"""Scans the class signature and returns argument names, types and default values. Returns: @@ -155,7 +149,7 @@ def get_init_arguments_and_types(cls: Any) -> List[Tuple[str, Tuple, Any]]: return name_type_default -def _get_abbrev_qualified_cls_name(cls: Any) -> str: +def _get_abbrev_qualified_cls_name(cls: _ARGPARSE_CLS) -> str: assert isinstance(cls, type), repr(cls) if cls.__module__.startswith("pytorch_lightning."): # Abbreviate. 
@@ -165,8 +159,11 @@ def _get_abbrev_qualified_cls_name(cls: Any) -> str: def add_argparse_args( - cls: Type["pl.Trainer"], parent_parser: ArgumentParser, *, use_argument_group: bool = True -) -> Union[_ArgumentGroup, ArgumentParser]: + cls: _ARGPARSE_CLS, + parent_parser: ArgumentParser, + *, + use_argument_group: bool = True, +) -> _ADD_ARGPARSE_RETURN: r"""Extends existing argparse by default attributes for ``cls``. Args: @@ -207,10 +204,10 @@ def add_argparse_args( >>> args = parser.parse_args([]) """ if isinstance(parent_parser, _ArgumentGroup): - raise RuntimeError("Please only pass an ArgumentParser instance.") + raise RuntimeError("Please only pass an `ArgumentParser` instance.") if use_argument_group: group_name = _get_abbrev_qualified_cls_name(cls) - parser: Union[_ArgumentGroup, ArgumentParser] = parent_parser.add_argument_group(group_name) + parser: _ADD_ARGPARSE_RETURN = parent_parser.add_argument_group(group_name) else: parser = ArgumentParser(parents=[parent_parser], add_help=False) @@ -222,7 +219,7 @@ def add_argparse_args( # Get symbols from cls or init function. 
for symbol in (cls, cls.__init__): - args_and_types = get_init_arguments_and_types(symbol) + args_and_types = get_init_arguments_and_types(symbol) # type: ignore[arg-type] args_and_types = [x for x in args_and_types if x[0] not in ignore_arg_names] if len(args_and_types) > 0: break diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index 9f2db6422612f..7ab3d6948854c 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -16,6 +16,7 @@ - Do not include any `_TYPE` suffix - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ +from argparse import _ArgumentGroup, ArgumentParser from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path @@ -50,6 +51,7 @@ EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] _DEVICE = Union[torch.device, str, int] _MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] +_ADD_ARGPARSE_RETURN = Union[_ArgumentGroup, ArgumentParser] @runtime_checkable diff --git a/tests/tests_pytorch/utilities/test_argparse.py b/tests/tests_pytorch/utilities/test_argparse.py index 1e83d96a0648e..2a88e8db531f9 100644 --- a/tests/tests_pytorch/utilities/test_argparse.py +++ b/tests/tests_pytorch/utilities/test_argparse.py @@ -207,7 +207,7 @@ def test_add_argparse_args(cls, name): def test_negative_add_argparse_args(): - with pytest.raises(RuntimeError, match="Please only pass an ArgumentParser instance."): + with pytest.raises(RuntimeError, match="Please only pass an `ArgumentParser` instance."): parser = ArgumentParser() add_argparse_args(AddArgparseArgsExampleClass, parser.add_argument_group("bad workflow"))