@@ -15,22 +15,26 @@
 from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Union
 
 import torch
+from torch import Tensor
+from torch.nn import Module
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
-from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
 
+if _NATIVE_AMP_AVAILABLE:
+    from torch.cuda.amp import GradScaler
+
 _STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]
 
 
-class Accelerator(object):
+class Accelerator:
     """
     The Accelerator Base Class.
     An Accelerator is meant to deal with one type of Hardware.
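The new `_NATIVE_AMP_AVAILABLE` guard keeps the `GradScaler` import from failing on torch builds without native AMP. A minimal sketch of what such a flag boils down to; the real flag in `pytorch_lightning.utilities` is version-based, so treat this reconstruction as an assumption:

```python
import torch

# Rough equivalent of the availability flag (an assumption; the real one
# checks the installed torch version):
_NATIVE_AMP_AVAILABLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "GradScaler")

if _NATIVE_AMP_AVAILABLE:
    from torch.cuda.amp import GradScaler  # only imported when the torch build provides it
```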
@@ -52,7 +56,6 @@ def __init__(
         training_type_plugin: TrainingTypePlugin,
     ) -> None:
         """
-
         Args:
             precision_plugin: the plugin to handle precision-specific parts
             training_type_plugin: the plugin to handle different training routines
@@ -64,7 +67,7 @@ def __init__(
         self.lr_schedulers: Sequence = []
         self.optimizer_frequencies: Sequence = []
 
-    def connect(self, model: LightningModule) -> None:
+    def connect(self, model: 'pl.LightningModule') -> None:
         """Transfers ownership of the model to this plugin"""
         self.training_type_plugin.connect(model)
 
@@ -76,7 +79,7 @@ def setup_environment(self) -> None:
         """
         self.training_type_plugin.setup_environment()
 
-    def setup(self, trainer: 'pl.Trainer', model: LightningModule) -> None:
+    def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
         """
         Setup plugins for the trainer fit and creates optimizers.
 
@@ -111,22 +114,22 @@ def post_dispatch(self, trainer: 'pl.Trainer') -> None:
         self.precision_plugin.post_dispatch()
 
     @property
-    def model(self) -> torch.nn.Module:
-        """Returns the model. This can also be a wrapped LightningModule.
+    def model(self) -> Module:
+        """
+        Returns the model. This can also be a wrapped LightningModule.
         For retrieving the pure LightningModule use :attr:`Accelerator.lightning_module`
-
         """
         return self.training_type_plugin.model
 
     @model.setter
-    def model(self, new_model: torch.nn.Module) -> None:
+    def model(self, new_model: Module) -> None:
         self.training_type_plugin.model = new_model
 
     @property
-    def lightning_module(self) -> LightningModule:
-        """Returns the pure LightningModule.
+    def lightning_module(self) -> 'pl.LightningModule':
+        """
+        Returns the pure LightningModule.
         To get the potentially wrapped model use :attr:`Accelerator.model`
-
         """
         return self.training_type_plugin.lightning_module
 
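The `model` / `lightning_module` pair exists because training-type plugins may hand back a wrapped module. A minimal sketch of the distinction, assuming a DDP-style wrapper:

```python
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel

def unwrap(model: Module) -> Module:
    # `model` may be the wrapped module (e.g. DistributedDataParallel);
    # `lightning_module` is always the plain module underneath.
    return model.module if isinstance(model, DistributedDataParallel) else model
```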
@@ -135,7 +138,8 @@ def root_device(self) -> torch.device:
         return self.training_type_plugin.root_device
 
     def teardown(self) -> None:
-        """This method is called to teardown the training process.
+        """
+        This method is called to teardown the training process.
         It is the right place to release memory and free other resources.
         """
         pass
@@ -268,13 +272,13 @@ def validation_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
 
     def backward(
         self,
-        closure_loss: torch.Tensor,
+        closure_loss: Tensor,
         optimizer: Optimizer,
         optimizer_idx: int,
         should_accumulate: bool,
         *args: Any,
         **kwargs: Any,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         """Forwards backward-calls to the precision plugin.
 
         Args:
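`backward` stays a thin passthrough; the precision plugin decides how the loss is handled before the actual backward call. A self-contained sketch of that forwarding shape (not the real `PrecisionPlugin` signature, just the idea):

```python
import torch
from torch import Tensor

class _PrecisionSketch:
    # Not the real PrecisionPlugin API: the plugin owns loss scaling and
    # the actual backward call; AMP variants would scale the loss first.
    def backward(self, closure_loss: Tensor) -> Tensor:
        closure_loss.backward()
        return closure_loss

loss = (torch.nn.Parameter(torch.ones(2)) ** 2).sum()
_PrecisionSketch().backward(loss)  # gradients now populated on the parameter
```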
@@ -325,9 +329,7 @@ def clip_gradients(
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
         """clips all the optimizer parameters to the given value"""
-        self.precision_plugin.clip_gradients(
-            self.model, optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm
-        )
+        self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
 
     def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
         """Hook to do something at the end of a training epoch
@@ -342,11 +344,11 @@ def on_train_end(self) -> None:
         pass
 
     def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
-        """creates optimizers and schedulers
+        """
+        Creates optimizers and schedulers
 
         Args:
             trainer: the Trainer, these optimizers should be connected to
-            model: the model to be optimized by the created optimizers
         """
         if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING):
             return
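With the stale `model:` entry gone from the docstring, the remaining behaviour is the state guard: optimizers are only built when the trainer is actually fitting or tuning. A one-function restatement of that guard, using the names from the diff:

```python
from pytorch_lightning.trainer.states import TrainerState

def needs_optimizers(state: TrainerState) -> bool:
    # setup_optimizers returns early for every other state
    return state in (TrainerState.FITTING, TrainerState.TUNING)
```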
@@ -357,7 +359,7 @@ def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
         self.lr_schedulers = lr_schedulers
         self.optimizer_frequencies = optimizer_frequencies
 
-    def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
+    def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'pl.LightningModule') -> None:
         """Attaches the training type plugin to the accelerator."""
         plugin.setup(model)
 
@@ -390,22 +392,21 @@ def precision(self) -> Union[str, int]:
         return self.precision_plugin.precision
 
     @property
-    def scaler(self) -> Optional['torch.cuda.amp.GradScaler']:
-
+    def scaler(self) -> Optional['GradScaler']:
         return getattr(self.precision_plugin, 'scaler', None)
 
     @property
     def rpc_enabled(self) -> bool:
         return self.training_type_plugin.rpc_enabled
 
-    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, torch.Tensor]:
+    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """
         Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
         plugins.
         """
         return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)
 
-    def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
+    def on_save(self, checkpoint: Dict[str, Union[Any, Tensor]]) -> Dict[str, Union[Any, Tensor]]:
         return self.training_type_plugin.on_save(checkpoint)
 
     def barrier(self, name: Optional[str] = None) -> None:
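`optimizer_state` relies on a `getattr` fallback so plugins without a custom hook fall back to the optimizer's own `state_dict()`. The pattern in isolation (the stub plugin is an assumption):

```python
import torch
from torch.optim import SGD

class _PluginStub:
    pass  # no custom optimizer_state hook

optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
# Falls back to optimizer.state_dict() because the stub defines no hook:
state = getattr(_PluginStub(), 'optimizer_state', lambda opt: opt.state_dict())(optimizer)
```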
@@ -420,7 +421,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         """
         return self.training_type_plugin.broadcast(obj, src)
 
-    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
         """
         Function to gather a tensor from several distributed processes.
 
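`all_gather` returns the input with a leading world-size dimension; the real gathering is delegated to the training-type plugin. A single-process sketch of that contract:

```python
import torch

def all_gather_world_size_one(tensor: torch.Tensor) -> torch.Tensor:
    # with one process the gathered result is just (1, *tensor.shape)
    return tensor.unsqueeze(0)

assert all_gather_world_size_one(torch.zeros(4)).shape == (1, 4)
```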
@@ -464,7 +465,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:
             yield
 
     # todo: remove in v1.5
-    def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
+    def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'pl.LightningModule') -> None:
         """
         Attaches the training type plugin to the accelerator.
         Also transfers ownership of the model to this plugin
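`connect_training_type_plugin` is already marked for removal in v1.5. A sketch of how such an alias typically keeps working until then; `rank_zero_warn` is imported in this diff, but the warning text here is an assumption:

```python
from pytorch_lightning.utilities import rank_zero_warn

class _DeprecationSketch:
    def setup_training_type_plugin(self, plugin, model) -> None:
        plugin.setup(model)

    def connect_training_type_plugin(self, plugin, model) -> None:
        # keep the old name working while steering callers to the new one
        rank_zero_warn("`connect_training_type_plugin` is deprecated, use `setup_training_type_plugin`")
        self.setup_training_type_plugin(plugin, model)
```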