22
22
from pytorch_lightning .core .hooks import CheckpointHooks , DataHooks
23
23
from pytorch_lightning .core .mixins import HyperparametersMixin
24
24
from pytorch_lightning .core .saving import _load_from_checkpoint
25
- from pytorch_lightning .utilities .argparse import add_argparse_args , from_argparse_args , get_init_arguments_and_types
26
- from pytorch_lightning .utilities .types import _PATH
25
+ from pytorch_lightning .utilities .argparse import (
26
+ add_argparse_args ,
27
+ from_argparse_args ,
28
+ get_init_arguments_and_types ,
29
+ parse_argparser ,
30
+ )
31
+ from pytorch_lightning .utilities .types import _ADD_ARGPARSE_RETURN , _PATH , EVAL_DATALOADERS , TRAIN_DATALOADERS
27
32
28
33
29
34
class LightningDataModule (CheckpointHooks , DataHooks , HyperparametersMixin ):
@@ -55,7 +60,7 @@ def teardown(self):
55
60
# called on every process in DDP
56
61
"""
57
62
58
- name : str = ...
63
+ name : Optional [ str ] = None
59
64
CHECKPOINT_HYPER_PARAMS_KEY = "datamodule_hyper_parameters"
60
65
CHECKPOINT_HYPER_PARAMS_NAME = "datamodule_hparams_name"
61
66
CHECKPOINT_HYPER_PARAMS_TYPE = "datamodule_hparams_type"
@@ -66,7 +71,7 @@ def __init__(self) -> None:
66
71
self .trainer : Optional ["pl.Trainer" ] = None
67
72
68
73
@classmethod
69
- def add_argparse_args (cls , parent_parser : ArgumentParser , ** kwargs ) -> ArgumentParser :
74
+ def add_argparse_args (cls , parent_parser : ArgumentParser , ** kwargs : Any ) -> _ADD_ARGPARSE_RETURN :
70
75
"""Extends existing argparse by default `LightningDataModule` attributes.
71
76
72
77
Example::
@@ -77,7 +82,9 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentP
77
82
return add_argparse_args (cls , parent_parser , ** kwargs )
78
83
79
84
@classmethod
80
- def from_argparse_args (cls , args : Union [Namespace , ArgumentParser ], ** kwargs ):
85
+ def from_argparse_args (
86
+ cls , args : Union [Namespace , ArgumentParser ], ** kwargs : Any
87
+ ) -> Union ["pl.LightningDataModule" , "pl.Trainer" ]:
81
88
"""Create an instance from CLI arguments.
82
89
83
90
Args:
@@ -92,6 +99,10 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
92
99
"""
93
100
return from_argparse_args (cls , args , ** kwargs )
94
101
102
@classmethod
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
    """Forward ``arg_parser`` to the ``parse_argparser`` utility on behalf of this class.

    Args:
        arg_parser: Either an ``ArgumentParser`` or an existing ``Namespace``.

    Returns:
        The ``Namespace`` produced by the utility function.
    """
    # Inside the method body the bare name resolves to the module-level helper
    # imported from ``pytorch_lightning.utilities.argparse``, not to this
    # classmethod, so this is plain delegation rather than recursion.
    namespace = parse_argparser(cls, arg_parser)
    return namespace
105
+
95
106
@classmethod
96
107
def get_init_arguments_and_types (cls ) -> List [Tuple [str , Tuple , Any ]]:
97
108
r"""Scans the DataModule signature and returns argument names, types and default values.
@@ -102,6 +113,15 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
102
113
"""
103
114
return get_init_arguments_and_types (cls )
104
115
116
@classmethod
def get_deprecated_arg_names(cls) -> List:
    """Collect the deprecated argument names declared directly on this class.

    Looks only at the class's own namespace (inherited attributes are not
    visited). Every attribute whose name starts with ``"DEPRECATED"`` and whose
    value is a tuple or list contributes its elements, in definition order.

    Returns:
        A flat list of the collected names; empty when nothing matches.
    """
    return [
        arg_name
        for attr_name, attr_value in vars(cls).items()
        if attr_name.startswith("DEPRECATED") and isinstance(attr_value, (tuple, list))
        for arg_name in attr_value
    ]
124
+
105
125
@classmethod
106
126
def from_datasets (
107
127
cls ,
@@ -112,7 +132,7 @@ def from_datasets(
112
132
batch_size : int = 1 ,
113
133
num_workers : int = 0 ,
114
134
** datamodule_kwargs : Any ,
115
- ):
135
+ ) -> "LightningDataModule" :
116
136
r"""
117
137
Create an instance from torch.utils.data.Dataset.
118
138
@@ -133,24 +153,32 @@ def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
133
153
shuffle &= not isinstance (ds , IterableDataset )
134
154
return DataLoader (ds , batch_size = batch_size , shuffle = shuffle , num_workers = num_workers , pin_memory = True )
135
155
136
def train_dataloader() -> TRAIN_DATALOADERS:
    """Build the training dataloader(s), mirroring the container shape of ``train_dataset``.

    Returns:
        A dict of dataloaders when ``train_dataset`` is a ``Mapping``, a list
        when it is a ``Sequence``, otherwise a single dataloader. Every loader
        is requested with ``shuffle=True``.
    """
    # Identity check, not truthiness: ``assert train_dataset`` would call
    # ``__len__``/``__bool__`` and reject a valid but empty dataset or
    # collection. The assert only narrows the Optional for type checkers.
    assert train_dataset is not None

    if isinstance(train_dataset, Mapping):
        return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
    if isinstance(train_dataset, Sequence):
        return [dataloader(ds, shuffle=True) for ds in train_dataset]
    return dataloader(train_dataset, shuffle=True)
142
164
143
def val_dataloader() -> EVAL_DATALOADERS:
    """Build the validation dataloader(s) for ``val_dataset``.

    Returns:
        A list of dataloaders when ``val_dataset`` is a ``Sequence``, otherwise
        a single dataloader. Loaders use the helper's default ``shuffle``.
    """
    # Identity check, not truthiness: ``assert val_dataset`` would call
    # ``__len__``/``__bool__`` and reject a valid but empty dataset or list.
    # The assert only narrows the Optional for type checkers.
    assert val_dataset is not None

    if isinstance(val_dataset, Sequence):
        return [dataloader(ds) for ds in val_dataset]
    return dataloader(val_dataset)
147
171
148
- def test_dataloader ():
172
+ def test_dataloader () -> EVAL_DATALOADERS :
173
+ assert test_dataset
174
+
149
175
if isinstance (test_dataset , Sequence ):
150
176
return [dataloader (ds ) for ds in test_dataset ]
151
177
return dataloader (test_dataset )
152
178
153
def predict_dataloader() -> EVAL_DATALOADERS:
    """Build the prediction dataloader(s) for ``predict_dataset``.

    Returns:
        A list of dataloaders when ``predict_dataset`` is a ``Sequence``,
        otherwise a single dataloader. Loaders use the helper's default
        ``shuffle``.
    """
    # Identity check, not truthiness: ``assert predict_dataset`` would call
    # ``__len__``/``__bool__`` and reject a valid but empty dataset or list.
    # The assert only narrows the Optional for type checkers.
    assert predict_dataset is not None

    if isinstance(predict_dataset, Sequence):
        return [dataloader(ds) for ds in predict_dataset]
    return dataloader(predict_dataset)
@@ -161,19 +189,19 @@ def predict_dataloader():
161
189
if accepts_kwargs :
162
190
special_kwargs = candidate_kwargs
163
191
else :
164
- accepted_params = set (accepted_params )
165
- accepted_params .discard ("self" )
166
- special_kwargs = {k : v for k , v in candidate_kwargs .items () if k in accepted_params }
192
+ accepted_param_names = set (accepted_params )
193
+ accepted_param_names .discard ("self" )
194
+ special_kwargs = {k : v for k , v in candidate_kwargs .items () if k in accepted_param_names }
167
195
168
196
datamodule = cls (** datamodule_kwargs , ** special_kwargs )
169
197
if train_dataset is not None :
170
- datamodule .train_dataloader = train_dataloader
198
+ datamodule .train_dataloader = train_dataloader # type: ignore[assignment]
171
199
if val_dataset is not None :
172
- datamodule .val_dataloader = val_dataloader
200
+ datamodule .val_dataloader = val_dataloader # type: ignore[assignment]
173
201
if test_dataset is not None :
174
- datamodule .test_dataloader = test_dataloader
202
+ datamodule .test_dataloader = test_dataloader # type: ignore[assignment]
175
203
if predict_dataset is not None :
176
- datamodule .predict_dataloader = predict_dataloader
204
+ datamodule .predict_dataloader = predict_dataloader # type: ignore[assignment]
177
205
return datamodule
178
206
179
207
def state_dict (self ) -> Dict [str , Any ]:
@@ -197,8 +225,8 @@ def load_from_checkpoint(
197
225
cls ,
198
226
checkpoint_path : Union [_PATH , IO ],
199
227
hparams_file : Optional [_PATH ] = None ,
200
- ** kwargs ,
201
- ):
228
+ ** kwargs : Any ,
229
+ ) -> Union [ "pl.LightningModule" , "pl.LightningDataModule" ] :
202
230
r"""
203
231
Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint
204
232
it stores the arguments passed to ``__init__`` in the checkpoint under ``"datamodule_hyper_parameters"``.
0 commit comments