# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
- from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
+ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Union

import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader

+ import pytorch_lightning as pl
from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum

- if TYPE_CHECKING:
-     from torch.cuda.amp import GradScaler
-
-     from pytorch_lightning.trainer.trainer import Trainer
-
_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]
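The hunk above drops the `TYPE_CHECKING` guard in favour of a plain `import pytorch_lightning as pl`, so `Trainer` can be named via the string annotation `'pl.Trainer'` without triggering a circular import at definition time. A minimal standalone sketch (not Lightning code) of why quoted annotations make this safe:

```python
# A string annotation is stored unevaluated, so the named class may be defined
# later -- or live in a module that would otherwise import us back.
import typing


def run(trainer: 'Trainer') -> None:  # 'Trainer' does not exist yet -- still fine
    trainer.fit()


class Trainer:  # defined after `run`; resolution happens lazily
    def fit(self) -> None:
        print("fit called")


run(Trainer())                     # works: the annotation is never evaluated at runtime
print(typing.get_type_hints(run))  # resolves lazily to {'trainer': <class '__main__.Trainer'>, ...}
```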
@@ -40,6 +36,7 @@ class Accelerator(object):
    An Accelerator is meant to deal with one type of Hardware.

    Currently there are accelerators for:
+
    - CPU
    - GPU
    - TPU
@@ -79,9 +76,10 @@ def setup_environment(self) -> None:
        """
        self.training_type_plugin.setup_environment()

-     def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
+     def setup(self, trainer: 'pl.Trainer', model: LightningModule) -> None:
        """
        Sets up plugins for the trainer fit and creates optimizers.
+
        Args:
            trainer: the trainer instance
            model: the LightningModule
@@ -91,23 +89,23 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
        self.setup_optimizers(trainer)
        self.setup_precision_plugin(self.precision_plugin)

-     def start_training(self, trainer: 'Trainer') -> None:
+     def start_training(self, trainer: 'pl.Trainer') -> None:
        self.training_type_plugin.start_training(trainer)

-     def start_evaluating(self, trainer: 'Trainer') -> None:
+     def start_evaluating(self, trainer: 'pl.Trainer') -> None:
        self.training_type_plugin.start_evaluating(trainer)

-     def start_predicting(self, trainer: 'Trainer') -> None:
+     def start_predicting(self, trainer: 'pl.Trainer') -> None:
        self.training_type_plugin.start_predicting(trainer)

-     def pre_dispatch(self, trainer: 'Trainer') -> None:
+     def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
        """Hook to do something before the training/evaluation/prediction starts."""
        self.training_type_plugin.pre_dispatch()
        if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
            self.setup_optimizers(trainer)
        self.precision_plugin.pre_dispatch()

-     def post_dispatch(self, trainer: 'Trainer') -> None:
+     def post_dispatch(self, trainer: 'pl.Trainer') -> None:
        """Hook to do something after the training/evaluation/prediction finishes."""
        self.training_type_plugin.post_dispatch()
        self.precision_plugin.post_dispatch()
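For orientation, these hooks fire in a fixed order around a run. The sketch below is hypothetical driver code written for illustration only; the real sequencing lives inside `pytorch_lightning.trainer.Trainer`, not in user code:

```python
# Hypothetical driver, assuming the call order implied by the method names above.
def run_fit(accelerator, trainer, model):
    accelerator.setup_environment()      # claim hardware, e.g. set the CUDA device
    accelerator.setup(trainer, model)    # plugins and (usually) optimizers
    accelerator.pre_dispatch(trainer)    # optimizers may instead be created here
    accelerator.start_training(trainer)  # hand control to the training type plugin
    accelerator.post_dispatch(trainer)   # plugin/precision teardown after the run
```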
@@ -169,12 +167,13 @@ def training_step(

        Args:
            args: the arguments for the model's training step. Can consist of the following:
-                 batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                     The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                 batch_idx (int): Integer displaying index of this batch
-                 optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-                 hiddens(:class:`~torch.Tensor`): Passed in if
-                     :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
+
+                 - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+                   The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+                 - batch_idx (int): Integer displaying index of this batch
+                 - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
+                 - hiddens (:class:`~torch.Tensor`): Passed in if
+                   :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.

        """
        args[0] = self.to_device(args[0])
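On the user side, these positional args mirror the signature of `LightningModule.training_step`. A minimal module consuming them might look like this (sketch, assuming a standard supervised setup):

```python
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule


class LitClassifier(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 10)

    # batch and batch_idx arrive exactly as documented above; optimizer_idx and
    # hiddens would be appended for multi-optimizer or truncated-BPTT runs.
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.layer(x), y)
        return loss
```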
@@ -190,11 +189,12 @@ def validation_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:

        Args:
            args: the arguments for the model's validation step. Can consist of the following:
-                 batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                     The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                 batch_idx (int): The index of this batch
-                 dataloader_idx (int): The index of the dataloader that produced this batch
-                     (only if multiple val dataloaders used)
+
+                 - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+                   The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+                 - batch_idx (int): The index of this batch
+                 - dataloader_idx (int): The index of the dataloader that produced this batch
+                   (only if multiple val dataloaders used)

        """
        batch = self.to_device(args[0])
@@ -208,11 +208,12 @@ def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:

        Args:
            args: the arguments for the model's test step. Can consist of the following:
-                 batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                     The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                 batch_idx (int): The index of this batch.
-                 dataloader_idx (int): The index of the dataloader that produced this batch
-                     (only if multiple test dataloaders used).
+
+                 - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+                   The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+                 - batch_idx (int): The index of this batch.
+                 - dataloader_idx (int): The index of the dataloader that produced this batch
+                   (only if multiple test dataloaders used).

        """
        batch = self.to_device(args[0])
@@ -226,11 +227,13 @@ def predict_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:

        Args:
            args: the arguments for the model's predict step. Can consist of the following:
-                 batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                     The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                 batch_idx (int): The index of this batch.
-                 dataloader_idx (int): The index of the dataloader that produced this batch
-                     (only if multiple predict dataloaders used).
+
+                 - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+                   The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+                 - batch_idx (int): The index of this batch.
+                 - dataloader_idx (int): The index of the dataloader that produced this batch
+                   (only if multiple predict dataloaders used).
+
        """
        batch = self.to_device(args[0])
@@ -336,7 +339,7 @@ def on_train_end(self) -> None:
        """Hook to do something at the end of the training"""
        pass

-     def setup_optimizers(self, trainer: 'Trainer') -> None:
+     def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
        """Creates optimizers and schedulers

        Args:
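`setup_optimizers` ultimately pulls its optimizers and schedulers from the module's `configure_optimizers` hook. A typical return shape, as a sketch of a method defined inside a `LightningModule`:

```python
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR


# Returning one optimizer and one scheduler is the common case; a bare
# optimizer, or lists of each, are equally valid return values.
def configure_optimizers(self):
    optimizer = Adam(self.parameters(), lr=1e-3)
    scheduler = StepLR(optimizer, step_size=10)
    return [optimizer], [scheduler]
```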
@@ -385,7 +388,7 @@ def precision(self) -> Union[str, int]:
        return self.precision_plugin.precision

    @property
-     def scaler(self) -> Optional['GradScaler']:
+     def scaler(self) -> Optional['torch.cuda.amp.GradScaler']:

        return getattr(self.precision_plugin, 'scaler', None)
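The `scaler` property simply exposes the native-AMP plugin's `torch.cuda.amp.GradScaler` when one exists. For context, the standard scaled-step pattern such a scaler participates in (standalone PyTorch sketch, not Lightning internals):

```python
import torch
import torch.nn.functional as F

scaler = torch.cuda.amp.GradScaler()  # the same kind of object the property returns


def amp_step(model, optimizer, batch, target):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():            # run the forward pass in mixed precision
        loss = F.mse_loss(model(batch), target)
    scaler.scale(loss).backward()              # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)                     # unscales grads, steps only if finite
    scaler.update()                            # adjust the scale factor for next step
```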
@@ -423,6 +426,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
            tensor: tensor of shape (batch, ...)
            group: the process group to gather results from. Defaults to all processes (world)
            sync_grads: flag that allows users to synchronize gradients for the all_gather op
+
        Return:
            A tensor of shape (world_size, batch, ...)
        """
@@ -451,7 +455,8 @@ def model_sharded_context(self) -> Generator[None, None, None]:
        shard the model instantly - useful for extremely large models. Can save memory and
        initialization time.

-         Returns: Model parallel context.
+         Returns:
+             Model parallel context.
        """
        with self.training_type_plugin.model_sharded_context():
            yield
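To make the contract concrete, here is a no-op stand-in of the same shape; a real plugin's version would patch module construction so layers built inside the block are created already sharded:

```python
import contextlib
from typing import Generator


@contextlib.contextmanager
def model_sharded_context() -> Generator[None, None, None]:
    # a real plugin would install instantiation hooks here
    yield
    # ...and tear them down here


with model_sharded_context():
    pass  # build the (very large) model here
```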
@@ -498,7 +503,9 @@ def call_configure_sharded_model_hook(self) -> bool:
        """
        Allow model parallel hook to be called in suitable environments determined by the training type plugin.
        This is useful when we want to shard the model once within fit.
-         Returns: True if we want to call the model parallel setup hook.
+
+         Returns:
+             True if we want to call the model parallel setup hook.
        """
        return self.training_type_plugin.call_configure_sharded_model_hook
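On the user side, the hook this property gates is `LightningModule.configure_sharded_model`; a sketch of a module that defers building its layers until the sharded context is active:

```python
import torch
from pytorch_lightning import LightningModule


class BigModel(LightningModule):
    # Called once within fit when the plugin enables the hook above; layers
    # created here can be instantiated directly in sharded form.
    def configure_sharded_model(self):
        self.block = torch.nn.Sequential(
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(),
            torch.nn.Linear(4096, 4096),
        )
```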
@@ -512,7 +519,9 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
        Override to delay setting optimizers and schedulers until after dispatch.
        This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
        However, this may break certain precision plugins, such as APEX, which require optimizers to be set.
-         Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
+
+         Returns:
+             If True, delay optimizer setup until `pre_dispatch`; otherwise call it within `setup`.
        """
        return self.training_type_plugin.setup_optimizers_in_pre_dispatch
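A custom plugin opts into the delayed behaviour by overriding this property. An illustrative subclass, not a shipped plugin (a concrete plugin would also implement the abstract `TrainingTypePlugin` methods):

```python
from pytorch_lightning.plugins.training_type import TrainingTypePlugin


class MyShardedPlugin(TrainingTypePlugin):
    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # Optimizers must see the wrapped (sharded) parameters, so wait until
        # pre_dispatch has wrapped the model before creating them.
        return True
```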