
Commit 20ff50c

awaelchli authored
Accelerator API docs (#6936)
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent b85cfbe commit 20ff50c

7 files changed, +125 -60 lines changed


docs/source/api_references.rst

+15
@@ -1,6 +1,21 @@
 API References
 ==============

+Accelerator API
+---------------
+
+.. currentmodule:: pytorch_lightning.accelerators
+
+.. autosummary::
+    :toctree: api
+    :nosignatures:
+    :template: classtemplate.rst
+
+    Accelerator
+    CPUAccelerator
+    GPUAccelerator
+    TPUAccelerator
+
 Core API
 --------

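The autosummary block above documents the classes at the top of the ``pytorch_lightning.accelerators`` package. As a quick, hedged sanity check (assuming a Lightning install matching this commit), the entries map one-to-one to importable names:

    # Hypothetical smoke test, not part of the commit: confirm the four documented
    # classes are importable from the module named by ``currentmodule`` above.
    from pytorch_lightning.accelerators import Accelerator, CPUAccelerator, GPUAccelerator, TPUAccelerator

    for cls in (CPUAccelerator, GPUAccelerator, TPUAccelerator):
        # Each hardware-specific class documented here derives from the base ``Accelerator``.
        print(cls.__name__, issubclass(cls, Accelerator))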

+50-4
@@ -1,10 +1,56 @@
+.. _accelerators:
+
 ############
 Accelerators
 ############
 Accelerators connect a Lightning Trainer to arbitrary accelerators (CPUs, GPUs, TPUs, etc). Accelerators
-also manage distributed accelerators (like DP, DDP, HPC cluster).
-
-Accelerators can also be configured to run on arbitrary clusters using Plugins or to link up to arbitrary
+also manage distributed communication through :ref:`Plugins` (like DP, DDP, HPC cluster) and
+can also be configured to run on arbitrary clusters or to link up to arbitrary
 computational strategies like 16-bit precision via AMP and Apex.

-**For help setting up custom plugin/accelerator please reach out to us at [email protected]**
+An Accelerator is meant to deal with one type of hardware.
+Currently there are accelerators for:
+
+- CPU
+- GPU
+- TPU
+
+Each Accelerator gets two plugins upon initialization:
+One to handle differences from the training routine and one to handle different precisions.
+
+.. testcode::
+
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.accelerators import GPUAccelerator
+    from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin
+
+    accelerator = GPUAccelerator(
+        precision_plugin=NativeMixedPrecisionPlugin(),
+        training_type_plugin=DDPPlugin(),
+    )
+    trainer = Trainer(accelerator=accelerator)
+
+
+We expose Accelerators and Plugins mainly for expert users who want to extend Lightning to work with new
+hardware and distributed training or clusters.
+
+
+.. warning:: The Accelerator API is in beta and subject to change.
+    For help setting up custom plugins/accelerators, please reach out to us at **[email protected]**
+
+----------
+
+
+Accelerator API
+---------------
+
+.. currentmodule:: pytorch_lightning.accelerators
+
+.. autosummary::
+    :nosignatures:
+    :template: classtemplate.rst
+
+    Accelerator
+    CPUAccelerator
+    GPUAccelerator
+    TPUAccelerator
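As a companion to the GPU ``testcode`` in the diff above, the same two-plugin wiring can be sketched for the CPU accelerator. This is a hedged example, not part of the commit: it assumes ``SingleDevicePlugin`` is exported from ``pytorch_lightning.plugins`` in this version (the ``PrecisionPlugin`` import path is the one used in ``accelerator.py`` below).

    import torch

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import CPUAccelerator
    from pytorch_lightning.plugins import SingleDevicePlugin  # assumed export for this version
    from pytorch_lightning.plugins.precision import PrecisionPlugin

    # One plugin for the training routine, one for precision -- the two plugins
    # every Accelerator receives on initialization, per the docs above.
    accelerator = CPUAccelerator(
        precision_plugin=PrecisionPlugin(),  # plain 32-bit precision, valid on CPU
        training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
    )
    trainer = Trainer(accelerator=accelerator)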

docs/source/extensions/plugins.rst

+2
@@ -1,3 +1,5 @@
+.. _plugins:
+
 #######
 Plugins
 #######

pytorch_lightning/accelerators/accelerator.py

+47-38
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Union

 import torch
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader

+import pytorch_lightning as pl
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
@@ -26,11 +27,6 @@
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum

-if TYPE_CHECKING:
-    from torch.cuda.amp import GradScaler
-
-    from pytorch_lightning.trainer.trainer import Trainer
-
 _STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]

@@ -40,6 +36,7 @@ class Accelerator(object):
     An Accelerator is meant to deal with one type of Hardware.

     Currently there are accelerators for:
+
     - CPU
     - GPU
     - TPU
@@ -79,9 +76,10 @@ def setup_environment(self) -> None:
         """
         self.training_type_plugin.setup_environment()

-    def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
+    def setup(self, trainer: 'pl.Trainer', model: LightningModule) -> None:
         """
         Setup plugins for the trainer fit and creates optimizers.
+
         Args:
             trainer: the trainer instance
             model: the LightningModule
@@ -91,23 +89,23 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
         self.setup_optimizers(trainer)
         self.setup_precision_plugin(self.precision_plugin)

-    def start_training(self, trainer: 'Trainer') -> None:
+    def start_training(self, trainer: 'pl.Trainer') -> None:
         self.training_type_plugin.start_training(trainer)

-    def start_evaluating(self, trainer: 'Trainer') -> None:
+    def start_evaluating(self, trainer: 'pl.Trainer') -> None:
         self.training_type_plugin.start_evaluating(trainer)

-    def start_predicting(self, trainer: 'Trainer') -> None:
+    def start_predicting(self, trainer: 'pl.Trainer') -> None:
         self.training_type_plugin.start_predicting(trainer)

-    def pre_dispatch(self, trainer: 'Trainer') -> None:
+    def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.pre_dispatch()
         if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
             self.setup_optimizers(trainer)
         self.precision_plugin.pre_dispatch()

-    def post_dispatch(self, trainer: 'Trainer') -> None:
+    def post_dispatch(self, trainer: 'pl.Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.post_dispatch()
         self.precision_plugin.post_dispatch()
@@ -169,12 +167,13 @@ def training_step(

         Args:
             args: the arguments for the models training step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): Integer displaying index of this batch
-                optimizer_idx (int): When using multiple optimizers, this argument will also be present.
-                hiddens(:class:`~torch.Tensor`): Passed in if
-                    :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.
+
+                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+                - batch_idx (int): Integer displaying index of this batch
+                - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
+                - hiddens(:class:`~torch.Tensor`): Passed in if
+                  :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.

         """
         args[0] = self.to_device(args[0])
@@ -190,11 +189,12 @@ def validation_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:

         Args:
             args: the arguments for the models validation step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): The index of this batch
-                dataloader_idx (int): The index of the dataloader that produced this batch
-                    (only if multiple val dataloaders used)
+
+                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+                - batch_idx (int): The index of this batch
+                - dataloader_idx (int): The index of the dataloader that produced this batch
+                  (only if multiple val dataloaders used)
         """
         batch = self.to_device(args[0])

@@ -208,11 +208,12 @@ def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:

         Args:
             args: the arguments for the models test step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): The index of this batch.
-                dataloader_idx (int): The index of the dataloader that produced this batch
-                    (only if multiple test dataloaders used).
+
+                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+                - batch_idx (int): The index of this batch.
+                - dataloader_idx (int): The index of the dataloader that produced this batch
+                  (only if multiple test dataloaders used).
         """
         batch = self.to_device(args[0])

@@ -226,11 +227,13 @@ def predict_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:

         Args:
             args: the arguments for the models predict step. Can consist of the following:
-                batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
-                    The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
-                batch_idx (int): The index of this batch.
-                dataloader_idx (int): The index of the dataloader that produced this batch
-                    (only if multiple predict dataloaders used).
+
+                - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
+                  The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
+                - batch_idx (int): The index of this batch.
+                - dataloader_idx (int): The index of the dataloader that produced this batch
+                  (only if multiple predict dataloaders used).
+
         """
         batch = self.to_device(args[0])

@@ -336,7 +339,7 @@ def on_train_end(self) -> None:
         """Hook to do something at the end of the training"""
         pass

-    def setup_optimizers(self, trainer: 'Trainer') -> None:
+    def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
         """creates optimizers and schedulers

         Args:
@@ -385,7 +388,7 @@ def precision(self) -> Union[str, int]:
         return self.precision_plugin.precision

     @property
-    def scaler(self) -> Optional['GradScaler']:
+    def scaler(self) -> Optional['torch.cuda.amp.GradScaler']:

         return getattr(self.precision_plugin, 'scaler', None)

@@ -423,6 +426,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
             tensor: tensor of shape (batch, ...)
             group: the process group to gather results from. Defaults to all processes (world)
             sync_grads: flag that allows users to synchronize gradients for all_gather op
+
         Return:
             A tensor of shape (world_size, batch, ...)
         """
@@ -451,7 +455,8 @@ def model_sharded_context(self) -> Generator[None, None, None]:
         shard the model instantly - useful for extremely large models. Can save memory and
         initialization time.

-        Returns: Model parallel context.
+        Returns:
+            Model parallel context.
         """
         with self.training_type_plugin.model_sharded_context():
             yield
@@ -498,7 +503,9 @@ def call_configure_sharded_model_hook(self) -> bool:
         """
         Allow model parallel hook to be called in suitable environments determined by the training type plugin.
         This is useful for when we want to shard the model once within fit.
-        Returns: True if we want to call the model parallel setup hook.
+
+        Returns:
+            True if we want to call the model parallel setup hook.
         """
         return self.training_type_plugin.call_configure_sharded_model_hook

@@ -512,7 +519,9 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
         Override to delay setting optimizers and schedulers till after dispatch.
         This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
         However this may break certain precision plugins such as APEX which require optimizers to be set.
-        Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
+
+        Returns:
+            If True, delay setup optimizers until `pre_dispatch`, else call within `setup`.
         """
         return self.training_type_plugin.setup_optimizers_in_pre_dispatch

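The hooks and properties documented in this file (``setup``, ``precision``, ``scaler``, the plugin attributes) are the extension surface for custom accelerators. Below is a minimal, hypothetical sketch of how a subclass might use them; the class name and logging are illustrative, not part of this commit.

    import logging

    from pytorch_lightning.accelerators import GPUAccelerator

    log = logging.getLogger(__name__)


    class VerboseGPUAccelerator(GPUAccelerator):
        """Hypothetical subclass: report how the accelerator was wired after ``setup``."""

        def setup(self, trainer, model):
            super().setup(trainer, model)
            # ``precision`` and ``scaler`` are the properties shown in the diff above;
            # ``scaler`` is ``None`` unless the precision plugin defines one.
            log.info(
                "%s ready: training_type=%s, precision=%s, scaler=%s",
                type(self).__name__,
                type(self.training_type_plugin).__name__,
                self.precision,
                type(self.scaler).__name__ if self.scaler is not None else None,
            )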

pytorch_lightning/accelerators/cpu.py

+3-7
@@ -11,20 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING
-
+import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

-if TYPE_CHECKING:
-    from pytorch_lightning.core.lightning import LightningModule
-    from pytorch_lightning.trainer.trainer import Trainer
-

 class CPUAccelerator(Accelerator):
+    """ Accelerator for CPU devices. """

-    def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+    def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
         """
         Raises:
             MisconfigurationException:
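A recurring change in these accelerator modules is replacing the ``if TYPE_CHECKING:`` imports with a plain ``import pytorch_lightning as pl`` and annotating with ``'pl.Trainer'`` / ``'pl.LightningModule'``. A small, self-contained sketch of the same pattern (the function below is illustrative, not from the commit):

    # Only the package object is bound here; ``Trainer`` itself is not imported at
    # module load time, which side-steps circular imports while keeping the string
    # annotations resolvable for tools that evaluate them later.
    import pytorch_lightning as pl


    def trainer_summary(trainer: 'pl.Trainer') -> str:
        # The annotation stays a string at runtime, so nothing is dereferenced
        # until a tool (docs build, type checker) chooses to resolve it.
        return f"{type(trainer).__name__}(max_epochs={trainer.max_epochs})"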

pytorch_lightning/accelerators/gpu.py

+4-6
@@ -13,24 +13,22 @@
 # limitations under the License.
 import logging
 import os
-from typing import Any, TYPE_CHECKING
+from typing import Any

 import torch

+import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.plugins import DataParallelPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

-if TYPE_CHECKING:
-    from pytorch_lightning.core.lightning import LightningModule
-    from pytorch_lightning.trainer.trainer import Trainer
-
 _log = logging.getLogger(__name__)


 class GPUAccelerator(Accelerator):
+    """ Accelerator for GPU devices. """

-    def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+    def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
         """
         Raises:
             MisconfigurationException:

pytorch_lightning/accelerators/tpu.py

+4-5
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, TYPE_CHECKING, Union
+from typing import Any, Callable, Union

 from torch.optim import Optimizer

@@ -28,14 +28,13 @@

 xla_clip_grad_norm_ = clip_grad_norm_

-if TYPE_CHECKING:
-    from pytorch_lightning.core.lightning import LightningModule
-    from pytorch_lightning.trainer.trainer import Trainer
+import pytorch_lightning as pl


 class TPUAccelerator(Accelerator):
+    """ Accelerator for TPU devices. """

-    def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+    def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
         """
         Raises:
             MisconfigurationException:
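For context beyond the diff: users rarely instantiate these classes directly. A hedged sketch of how each accelerator is normally selected, assuming the Trainer flags of this era (``gpus``, ``tpu_cores``) and that ``trainer.accelerator`` exposes the built instance; the selection logic itself lives in the Trainer and is not part of this commit.

    from pytorch_lightning import Trainer

    # The Trainer chooses the matching accelerator from its own flags; the classes
    # documented in this commit are what it builds under the hood.
    cpu_trainer = Trainer()             # CPUAccelerator
    gpu_trainer = Trainer(gpus=1)       # GPUAccelerator -- requires a CUDA device
    tpu_trainer = Trainer(tpu_cores=8)  # TPUAccelerator -- requires a TPU environment

    print(type(cpu_trainer.accelerator).__name__)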
