
Commit f29ecbf

Typing for accelerators and plugins (#7022)
1 parent f6f81f0 commit f29ecbf

File tree: 10 files changed (+95 additions, -111 deletions)


pytorch_lightning/accelerators/accelerator.py

Lines changed: 29 additions & 28 deletions
@@ -15,22 +15,26 @@
 from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Union
 
 import torch
+from torch import Tensor
+from torch.nn import Module
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
-from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
 
+if _NATIVE_AMP_AVAILABLE:
+    from torch.cuda.amp import GradScaler
+
 _STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]
 
 
-class Accelerator(object):
+class Accelerator:
     """
     The Accelerator Base Class.
     An Accelerator is meant to deal with one type of Hardware.
@@ -52,7 +56,6 @@ def __init__(
         training_type_plugin: TrainingTypePlugin,
     ) -> None:
         """
-
         Args:
             precision_plugin: the plugin to handle precision-specific parts
             training_type_plugin: the plugin to handle different training routines
@@ -64,7 +67,7 @@ def __init__(
         self.lr_schedulers: Sequence = []
         self.optimizer_frequencies: Sequence = []
 
-    def connect(self, model: LightningModule) -> None:
+    def connect(self, model: 'pl.LightningModule') -> None:
         """Transfers ownership of the model to this plugin"""
         self.training_type_plugin.connect(model)
 
@@ -76,7 +79,7 @@ def setup_environment(self) -> None:
         """
         self.training_type_plugin.setup_environment()
 
-    def setup(self, trainer: 'pl.Trainer', model: LightningModule) -> None:
+    def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
         """
         Setup plugins for the trainer fit and creates optimizers.
 
@@ -111,22 +114,22 @@ def post_dispatch(self, trainer: 'pl.Trainer') -> None:
         self.precision_plugin.post_dispatch()
 
     @property
-    def model(self) -> torch.nn.Module:
-        """Returns the model. This can also be a wrapped LightningModule.
+    def model(self) -> Module:
+        """
+        Returns the model. This can also be a wrapped LightningModule.
         For retrieving the pure LightningModule use :attr:`Accelerator.lightning_module`
-
         """
         return self.training_type_plugin.model
 
     @model.setter
-    def model(self, new_model: torch.nn.Module) -> None:
+    def model(self, new_model: Module) -> None:
         self.training_type_plugin.model = new_model
 
     @property
-    def lightning_module(self) -> LightningModule:
-        """Returns the pure LightningModule.
+    def lightning_module(self) -> 'pl.LightningModule':
+        """
+        Returns the pure LightningModule.
         To get the potentially wrapped model use :attr:`Accelerator.model`
-
         """
         return self.training_type_plugin.lightning_module
 
@@ -135,7 +138,8 @@ def root_device(self) -> torch.device:
         return self.training_type_plugin.root_device
 
     def teardown(self) -> None:
-        """This method is called to teardown the training process.
+        """
+        This method is called to teardown the training process.
         It is the right place to release memory and free other ressources.
         """
         pass
@@ -268,13 +272,13 @@ def validation_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
 
     def backward(
         self,
-        closure_loss: torch.Tensor,
+        closure_loss: Tensor,
         optimizer: Optimizer,
         optimizer_idx: int,
         should_accumulate: bool,
         *args: Any,
         **kwargs: Any,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         """Forwards backward-calls to the precision plugin.
 
         Args:
@@ -325,9 +329,7 @@ def clip_gradients(
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
         """clips all the optimizer parameters to the given value"""
-        self.precision_plugin.clip_gradients(
-            self.model, optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm
-        )
+        self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
 
     def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
         """Hook to do something on the end of an training epoch
@@ -342,11 +344,11 @@ def on_train_end(self) -> None:
         pass
 
     def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
-        """creates optimizers and schedulers
+        """
+        Creates optimizers and schedulers
 
         Args:
             trainer: the Trainer, these optimizers should be connected to
-            model: the model to be optimized by the created optimizers
         """
         if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING):
             return
@@ -357,7 +359,7 @@ def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
         self.lr_schedulers = lr_schedulers
         self.optimizer_frequencies = optimizer_frequencies
 
-    def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
+    def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'pl.LightningModule') -> None:
         """Attaches the training type plugin to the accelerator."""
         plugin.setup(model)
 
@@ -390,22 +392,21 @@ def precision(self) -> Union[str, int]:
         return self.precision_plugin.precision
 
     @property
-    def scaler(self) -> Optional['torch.cuda.amp.GradScaler']:
-
+    def scaler(self) -> Optional['GradScaler']:
         return getattr(self.precision_plugin, 'scaler', None)
 
     @property
     def rpc_enabled(self) -> bool:
         return self.training_type_plugin.rpc_enabled
 
-    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, torch.Tensor]:
+    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """
         Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
         plugins.
         """
         return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)
 
-    def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
+    def on_save(self, checkpoint: Dict[str, Union[Any, Tensor]]) -> Dict[str, Union[Any, Tensor]]:
         return self.training_type_plugin.on_save(checkpoint)
 
     def barrier(self, name: Optional[str] = None) -> None:
@@ -420,7 +421,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         """
        return self.training_type_plugin.broadcast(obj, src)
 
-    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
         """
         Function to gather a tensor from several distributed processes.
 
@@ -464,7 +465,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:
             yield
 
     # todo: remove in v1.5
-    def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
+    def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'pl.LightningModule') -> None:
         """
         Attaches the training type plugin to the accelerator.
         Also transfers ownership of the model to this plugin
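
The recurring pattern in this file is to annotate against `'pl.LightningModule'` as a quoted forward reference (keeping only `import pytorch_lightning as pl` at module level) and to guard the optional `GradScaler` import behind `_NATIVE_AMP_AVAILABLE`. Below is a minimal standalone sketch of the same deferred-resolution idea; it uses `typing.TYPE_CHECKING` instead of a runtime `pl` import so it runs without pytorch_lightning installed, and the class and attribute names are illustrative only, not the Lightning API.

# Minimal sketch (hypothetical names, not the Lightning code): guard optional
# imports behind a capability flag and keep heavyweight classes as quoted
# forward references, so importing this module never triggers them at runtime.
from typing import TYPE_CHECKING, Optional

import torch
from torch.nn import Module

# capability flag, analogous to _NATIVE_AMP_AVAILABLE in the diff above
_AMP_AVAILABLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "GradScaler")

if _AMP_AVAILABLE:
    from torch.cuda.amp import GradScaler

if TYPE_CHECKING:  # resolved only by static type checkers, avoids circular imports
    import pytorch_lightning as pl


class MyAccelerator:
    def __init__(self) -> None:
        self._model: Optional[Module] = None

    def connect(self, model: 'pl.LightningModule') -> None:
        # the quoted annotation is never evaluated at runtime
        self._model = model

    @property
    def scaler(self) -> Optional['GradScaler']:
        # returns None when native AMP is unavailable or no scaler was created
        return getattr(self, "_scaler", None)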

pytorch_lightning/accelerators/tpu.py

Lines changed: 7 additions & 10 deletions
@@ -15,6 +15,7 @@
 
 from torch.optim import Optimizer
 
+import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
@@ -26,10 +27,9 @@
     import torch_xla.core.xla_model as xm
     from torch_xla._patched_functions import clip_grad_norm_
 
+    # rename to mock in a test
     xla_clip_grad_norm_ = clip_grad_norm_
 
-import pytorch_lightning as pl
-
 
 class TPUAccelerator(Accelerator):
     """ Accelerator for TPU devices. """
@@ -59,19 +59,16 @@ def clip_gradients(
         self,
         optimizer: Optimizer,
         clip_val: Union[float, int],
-        norm_type: float = 2.0,
-        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM
+        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
-        assert gradient_clip_algorithm is GradClipAlgorithmType.NORM, \
+        assert gradient_clip_algorithm == GradClipAlgorithmType.NORM, \
            "Only NORM gradient clipping is supported on TPU for now"
 
-        model = self.lightning_module
-        parameters = model.parameters()
-
         grad_clip_val = float(clip_val)
         if grad_clip_val <= 0:
             return
 
-        max_norm = grad_clip_val
+        parameters = self.model.parameters()
+        norm_type = 2.0
 
-        xla_clip_grad_norm_(parameters, max_norm, norm_type)
+        xla_clip_grad_norm_(parameters, grad_clip_val, norm_type)
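
After this change, `TPUAccelerator.clip_gradients` always performs 2-norm clipping of `self.model.parameters()` via the patched XLA helper. The sketch below reproduces the same logic with the stock `torch.nn.utils.clip_grad_norm_` standing in for `xla_clip_grad_norm_`, since `torch_xla` is not assumed here; the free function name is hypothetical.

# Sketch of the clipping logic above, using stock PyTorch clip_grad_norm_ as a
# stand-in for xla_clip_grad_norm_ (which requires torch_xla).
from typing import Union

import torch
from torch.nn.utils import clip_grad_norm_


def clip_gradients_by_norm(model: torch.nn.Module, clip_val: Union[float, int]) -> None:
    grad_clip_val = float(clip_val)
    if grad_clip_val <= 0:
        # a non-positive value disables clipping, matching the early return above
        return
    # only 2-norm clipping is supported on TPU for now
    clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=2.0)


if __name__ == "__main__":
    net = torch.nn.Linear(4, 2)
    net(torch.randn(8, 4)).sum().backward()
    clip_gradients_by_norm(net, clip_val=0.5)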

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 23 additions & 17 deletions
@@ -11,14 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Generator, List, Sequence, Tuple, Type
+from typing import Any, Callable, ContextManager, Iterator, List, Sequence, Tuple, Type
 
 import torch
+from torch import Tensor
+from torch.nn import Module
 from torch.optim import Optimizer
 
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, rank_zero_warn
+from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
+
+PARAMETERS = Iterator[torch.nn.Parameter]
 
 if _APEX_AVAILABLE:
     from apex import amp
@@ -32,11 +36,15 @@ def __init__(self, amp_level: str = "O2") -> None:
         self.backend = AMPType.APEX
         self.amp_level = amp_level
 
-    def master_params(self, optimizer: Optimizer) -> Generator[torch.Tensor, None, None]:
+    def master_params(self, optimizer: Optimizer) -> PARAMETERS:
         return amp.master_params(optimizer)
 
-    def connect(self, model: torch.nn.Module, optimizers: Sequence[Optimizer],
-                lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence[Optimizer], Sequence[Any]]:
+    def connect(
+        self,
+        model: Module,
+        optimizers: Sequence[Optimizer],
+        lr_schedulers: Sequence[Any],
+    ) -> Tuple[Module, Sequence[Optimizer], Sequence[Any]]:
         """Connects the precision plugin to the training process,
         configures apex and reinits the schedulers
         """
@@ -49,28 +57,28 @@ def connect(self, model: torch.nn.Module, optimizers: Sequence[Optimizer],
     def backward(
         self,
         model: LightningModule,
-        closure_loss: torch.Tensor,
+        closure_loss: Tensor,
         optimizer: Optimizer,
         opt_idx: int,
         should_accumulate: bool,
         *args: Any,
         **kwargs: Any,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         """performs the actual backpropagation
 
         Args:
             model: the model to be optimized
             closure_loss: the loss value obtained from the closure
             optimizer: the optimizer to perform the step lateron
-            opt_idx: the optimizer's index
+            opt_idx: the optimizer index
             should_accumulate: whether to accumulate gradients or not
 
         """
-        closure_loss = amp.scale_loss(closure_loss, model.trainer.optimizers if optimizer is None else optimizer)
+        opt = model.trainer.optimizers if optimizer is None else optimizer
+        scaled_loss: ContextManager[Tensor] = amp.scale_loss(closure_loss, opt)
 
         # enter apex context
-        context = closure_loss
-        closure_loss = closure_loss.__enter__()
+        closure_loss = scaled_loss.__enter__()
 
         # do backward pass
         # TODO: not entirely sure, why we need this
@@ -84,10 +92,8 @@ def backward(
         closure_loss.backward(*args, **kwargs)
 
         # exit amp context
-        a, b, c = None, None, None
-        error = context.__exit__(a, b, c)
+        error = scaled_loss.__exit__(None, None, None)
         if error:
-            rank_zero_warn(a, b, c)
             raise Exception("apex unscale error")
 
         # once backward has been applied, release graph
@@ -97,17 +103,17 @@ def backward(
     def configure_apex(
         self,
         amp: Type,
-        model: LightningModule,
+        model: Module,
         optimizers: List[Optimizer],
         amp_level: str,
-    ) -> Tuple[LightningModule, List[Optimizer]]:
+    ) -> Tuple[Module, List[Optimizer]]:
         r"""
         Override to init AMP your own way.
         Must return a model and list of optimizers.
 
         Args:
             amp: pointer to amp library object.
-            model: pointer to current :class:`LightningModule`.
+            model: pointer to current :class:`torch.nn.Module`.
             optimizers: list of optimizers passed in :meth:`configure_optimizers`.
             amp_level: AMP mode chosen ('O1', 'O2', etc...)
 
pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 3 additions & 4 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 from typing import Any, Callable, Union
 
-import torch
+from torch import Tensor
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
@@ -54,13 +54,13 @@ def pre_optimizer_step(
     def backward(
         self,
         model: 'pl.LightningModule',
-        closure_loss: torch.Tensor,
+        closure_loss: Tensor,
         optimizer: Optimizer,
         opt_idx: int,
         should_accumulate: bool,
         *args: Any,
         **kwargs: Any,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         if is_overridden('backward', model):
             warning_cache.warn(
                 "Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles"
@@ -76,7 +76,6 @@ def backward(
 
     def clip_gradients(
         self,
-        model: 'pl.LightningModule',
         optimizer: Optimizer,
         clip_val: Union[int, float],
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
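
Since `model` is dropped from `clip_gradients`, custom precision plugins that override it now receive only the optimizer, the clip value, and the clipping algorithm. A sketch of the new override shape is below, with stand-in base class and enum so it runs without pytorch_lightning; the real classes live in the Lightning package.

# Sketch of the updated clip_gradients override shape (hypothetical stand-ins
# for pytorch_lightning's PrecisionPlugin and GradClipAlgorithmType).
from enum import Enum
from typing import Union

from torch.optim import Optimizer


class GradClipAlgorithmType(Enum):  # stand-in for the Lightning enum
    VALUE = "value"
    NORM = "norm"


class BasePrecisionPlugin:  # stand-in for the Lightning PrecisionPlugin
    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        ...


class EngineManagedPrecision(BasePrecisionPlugin):
    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        # the training engine (e.g. DeepSpeed) handles clipping itself, so this
        # is a no-op; note there is no `model` argument anymore
        pass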
