Typing for accelerators and plugins #7022

Merged: 31 commits, Apr 15, 2021

Commits (31)
9d7d440
Add typings for evaluation_loop.py
ethanwharris Apr 14, 2021
f3d5b54
Fix PEP
ethanwharris Apr 14, 2021
7b9f13b
Fix some tests
ethanwharris Apr 14, 2021
8d84fb4
Run pre-commit
ethanwharris Apr 14, 2021
9b29428
Apply suggestions from code review
Borda Apr 14, 2021
4f47b8b
Merge branch 'master' into docs/evaluation_loop_typing
carmocca Apr 14, 2021
03bde1a
Update setup.cfg
ethanwharris Apr 14, 2021
ad960bc
Fix some mypy issues
ethanwharris Apr 14, 2021
cfdbad7
Updates
ethanwharris Apr 14, 2021
e21ad93
Fix
ethanwharris Apr 14, 2021
42b60f9
Fix typing for accelerators and plugins
carmocca Apr 14, 2021
72d28a2
Merge branch 'master' into typing-accelerators-plugins
carmocca Apr 14, 2021
fdf0f0e
pre-commit
carmocca Apr 14, 2021
177f604
Fix mypy
carmocca Apr 14, 2021
75e499e
Fix typing
carmocca Apr 14, 2021
f949e33
Fix typing
carmocca Apr 14, 2021
04ff62e
Fix typing
carmocca Apr 14, 2021
49fe989
Duplicate import
carmocca Apr 14, 2021
ab32167
Fix typing
carmocca Apr 14, 2021
74d4376
Fix typing
carmocca Apr 14, 2021
429c61e
Merge branch 'master' into typing-accelerators-plugins
carmocca Apr 15, 2021
f40252c
Bad merge
carmocca Apr 15, 2021
0216d2a
Undo some changes
carmocca Apr 15, 2021
b539786
Undo forward references
carmocca Apr 15, 2021
4c93cf4
Address comment
carmocca Apr 15, 2021
a1b1247
Forward reference OSS
carmocca Apr 15, 2021
c677107
Forward reference GradScaler
carmocca Apr 15, 2021
e80cf44
Minor changes
carmocca Apr 15, 2021
6c14757
Update pytorch_lightning/accelerators/accelerator.py
carmocca Apr 15, 2021
e005c44
flake8
carmocca Apr 15, 2021
8dd9460
Update pytorch_lightning/plugins/precision/apex_amp.py
carmocca Apr 15, 2021
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -243,7 +243,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))


- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917))
- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917))


- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
57 changes: 29 additions & 28 deletions pytorch_lightning/accelerators/accelerator.py
@@ -15,22 +15,26 @@
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum

if _NATIVE_AMP_AVAILABLE:
from torch.cuda.amp import GradScaler

_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]


class Accelerator(object):
class Accelerator:
"""
The Accelerator Base Class.
An Accelerator is meant to deal with one type of Hardware.
@@ -52,7 +56,6 @@ def __init__(
training_type_plugin: TrainingTypePlugin,
) -> None:
"""

Args:
precision_plugin: the plugin to handle precision-specific parts
training_type_plugin: the plugin to handle different training routines
@@ -64,7 +67,7 @@ def __init__(
self.lr_schedulers: Sequence = []
self.optimizer_frequencies: Sequence = []

def connect(self, model: LightningModule) -> None:
def connect(self, model: 'pl.LightningModule') -> None:
"""Transfers ownership of the model to this plugin"""
self.training_type_plugin.connect(model)

@@ -76,7 +79,7 @@ def setup_environment(self) -> None:
"""
self.training_type_plugin.setup_environment()

def setup(self, trainer: 'pl.Trainer', model: LightningModule) -> None:
def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
"""
Setup plugins for the trainer fit and creates optimizers.

@@ -111,22 +114,22 @@ def post_dispatch(self, trainer: 'pl.Trainer') -> None:
self.precision_plugin.post_dispatch()

@property
def model(self) -> torch.nn.Module:
"""Returns the model. This can also be a wrapped LightningModule.
def model(self) -> Module:
"""
Returns the model. This can also be a wrapped LightningModule.
For retrieving the pure LightningModule use :attr:`Accelerator.lightning_module`

"""
return self.training_type_plugin.model

@model.setter
def model(self, new_model: torch.nn.Module) -> None:
def model(self, new_model: Module) -> None:
self.training_type_plugin.model = new_model

@property
def lightning_module(self) -> LightningModule:
"""Returns the pure LightningModule.
def lightning_module(self) -> 'pl.LightningModule':
"""
Returns the pure LightningModule.
To get the potentially wrapped model use :attr:`Accelerator.model`

"""
return self.training_type_plugin.lightning_module

@@ -135,7 +138,8 @@ def root_device(self) -> torch.device:
return self.training_type_plugin.root_device

def teardown(self) -> None:
"""This method is called to teardown the training process.
"""
This method is called to teardown the training process.
It is the right place to release memory and free other resources.
"""
pass
@@ -268,13 +272,13 @@ def validation_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:

def backward(
self,
closure_loss: torch.Tensor,
closure_loss: Tensor,
optimizer: Optimizer,
optimizer_idx: int,
should_accumulate: bool,
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
) -> Tensor:
"""Forwards backward-calls to the precision plugin.

Args:
@@ -325,9 +329,7 @@ def clip_gradients(
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
"""clips all the optimizer parameters to the given value"""
self.precision_plugin.clip_gradients(
self.model, optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm
)
self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
"""Hook to do something on the end of an training epoch
@@ -342,11 +344,11 @@ def on_train_end(self) -> None:
pass

def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
"""creates optimizers and schedulers
"""
Creates optimizers and schedulers

Args:
trainer: the Trainer, these optimizers should be connected to
model: the model to be optimized by the created optimizers
"""
if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING):
return
@@ -357,7 +359,7 @@ def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
self.lr_schedulers = lr_schedulers
self.optimizer_frequencies = optimizer_frequencies

def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'pl.LightningModule') -> None:
"""Attaches the training type plugin to the accelerator."""
plugin.setup(model)

@@ -390,22 +392,21 @@ def precision(self) -> Union[str, int]:
return self.precision_plugin.precision

@property
def scaler(self) -> Optional['torch.cuda.amp.GradScaler']:

def scaler(self) -> Optional['GradScaler']:
return getattr(self.precision_plugin, 'scaler', None)

@property
def rpc_enabled(self) -> bool:
return self.training_type_plugin.rpc_enabled

def optimizer_state(self, optimizer: Optimizer) -> Dict[str, torch.Tensor]:
def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
"""
Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
plugins.
"""
return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)

def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]:
def on_save(self, checkpoint: Dict[str, Union[Any, Tensor]]) -> Dict[str, Union[Any, Tensor]]:
return self.training_type_plugin.on_save(checkpoint)

def barrier(self, name: Optional[str] = None) -> None:
@@ -420,7 +421,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
"""
return self.training_type_plugin.broadcast(obj, src)

def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
"""
Function to gather a tensor from several distributed processes.

@@ -464,7 +465,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:
yield

# todo: remove in v1.5
def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'pl.LightningModule') -> None:
"""
Attaches the training type plugin to the accelerator.
Also transfers ownership of the model to this plugin
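A note on the typing pattern the accelerator.py changes rely on: `import pytorch_lightning as pl` plus string annotations such as `'pl.LightningModule'` avoid circular imports, and `GradScaler` is imported only when `_NATIVE_AMP_AVAILABLE` is true while the `scaler` property still annotates its return as `Optional['GradScaler']`. Below is a minimal standalone sketch of the same idea, with hypothetical names (`mypackage`, `MyModule`, `MyAccelerator`) standing in for the Lightning classes; only `torch` is assumed, and even that is optional here.

```python
from typing import TYPE_CHECKING, Optional

try:
    # Optional dependency, mirroring the `_NATIVE_AMP_AVAILABLE` guard.
    from torch.cuda.amp import GradScaler
    _NATIVE_AMP_AVAILABLE = True
except ImportError:
    _NATIVE_AMP_AVAILABLE = False

if TYPE_CHECKING:
    # Resolved only by the type checker, analogous to `import pytorch_lightning as pl`
    # followed by 'pl.LightningModule' string annotations. Hypothetical module name.
    from mypackage.module import MyModule


class MyAccelerator:
    """Sketch of an accelerator that owns a model and an optional grad scaler."""

    def __init__(self) -> None:
        self._model: Optional['MyModule'] = None
        self._scaler: Optional['GradScaler'] = None

    def connect(self, model: 'MyModule') -> None:
        # String annotation: MyModule never needs to be imported at runtime.
        self._model = model

    @property
    def scaler(self) -> Optional['GradScaler']:
        # Valid even when GradScaler was never imported, because the
        # forward-reference string is not evaluated at runtime.
        return self._scaler
```

Because the annotations are plain strings, mypy still checks them while the module imports cleanly without the optional dependency.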
17 changes: 7 additions & 10 deletions pytorch_lightning/accelerators/tpu.py
@@ -15,6 +15,7 @@

from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
@@ -26,10 +27,9 @@
import torch_xla.core.xla_model as xm
from torch_xla._patched_functions import clip_grad_norm_

# rename to mock in a test
xla_clip_grad_norm_ = clip_grad_norm_

import pytorch_lightning as pl


class TPUAccelerator(Accelerator):
""" Accelerator for TPU devices. """
@@ -59,19 +59,16 @@ def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[float, int],
norm_type: float = 2.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
assert gradient_clip_algorithm is GradClipAlgorithmType.NORM, \
assert gradient_clip_algorithm == GradClipAlgorithmType.NORM, \
"Only NORM gradient clipping is supported on TPU for now"

model = self.lightning_module
parameters = model.parameters()

grad_clip_val = float(clip_val)
if grad_clip_val <= 0:
return

max_norm = grad_clip_val
parameters = self.model.parameters()
Review comment from the Contributor Author: Is this equivalent to what we had before?

norm_type = 2.0

xla_clip_grad_norm_(parameters, max_norm, norm_type)
xla_clip_grad_norm_(parameters, grad_clip_val, norm_type)
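
On the inline question above: as far as the diff shows, the only behavioural difference is that `norm_type` is no longer a parameter; it previously defaulted to 2.0 and is now hard-coded, so callers that relied on the default reach `xla_clip_grad_norm_` with the same arguments. A small sketch, using a stand-in for the XLA function, illustrating that assumption:

```python
from typing import List, Tuple

calls: List[Tuple[float, float]] = []

def fake_xla_clip_grad_norm_(parameters, max_norm: float, norm_type: float) -> None:
    # Stand-in for torch_xla._patched_functions.clip_grad_norm_; records its arguments.
    calls.append((max_norm, norm_type))

def clip_old(parameters, clip_val, norm_type: float = 2.0) -> None:
    # Shape of the method before this PR.
    grad_clip_val = float(clip_val)
    if grad_clip_val <= 0:
        return
    max_norm = grad_clip_val
    fake_xla_clip_grad_norm_(parameters, max_norm, norm_type)

def clip_new(parameters, clip_val) -> None:
    # Shape of the method after this PR: norm_type is fixed at 2.0.
    grad_clip_val = float(clip_val)
    if grad_clip_val <= 0:
        return
    norm_type = 2.0
    fake_xla_clip_grad_norm_(parameters, grad_clip_val, norm_type)

clip_old([], 0.5)
clip_new([], 0.5)
assert calls[0] == calls[1]  # both forward (0.5, 2.0)
```

A caller passing a non-default `norm_type` would see a `TypeError` after this change, which appears to be the intent of tightening the signature.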
40 changes: 23 additions & 17 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -11,14 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Generator, List, Sequence, Tuple, Type
from typing import Any, Callable, ContextManager, Iterator, List, Sequence, Tuple, Type

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, rank_zero_warn
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType

PARAMETERS = Iterator[torch.nn.Parameter]

if _APEX_AVAILABLE:
from apex import amp
@@ -32,11 +36,15 @@ def __init__(self, amp_level: str = "O2") -> None:
self.backend = AMPType.APEX
self.amp_level = amp_level

def master_params(self, optimizer: Optimizer) -> Generator[torch.Tensor, None, None]:
def master_params(self, optimizer: Optimizer) -> PARAMETERS:
return amp.master_params(optimizer)

def connect(self, model: torch.nn.Module, optimizers: Sequence[Optimizer],
lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence[Optimizer], Sequence[Any]]:
def connect(
self,
model: Module,
optimizers: Sequence[Optimizer],
lr_schedulers: Sequence[Any],
) -> Tuple[Module, Sequence[Optimizer], Sequence[Any]]:
"""Connects the precision plugin to the training process,
configures apex and reinits the schedulers
"""
@@ -49,28 +57,28 @@ def connect(self, model: torch.nn.Module, optimizers: Sequence[Optimizer],
def backward(
self,
model: LightningModule,
closure_loss: torch.Tensor,
closure_loss: Tensor,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
) -> Tensor:
"""performs the actual backpropagation

Args:
model: the model to be optimized
closure_loss: the loss value obtained from the closure
optimizer: the optimizer to perform the step later on
opt_idx: the optimizer's index
opt_idx: the optimizer index
should_accumulate: whether to accumulate gradients or not

"""
closure_loss = amp.scale_loss(closure_loss, model.trainer.optimizers if optimizer is None else optimizer)
opt = model.trainer.optimizers if optimizer is None else optimizer
scaled_loss: ContextManager[Tensor] = amp.scale_loss(closure_loss, opt)

# enter apex context
context = closure_loss
closure_loss = closure_loss.__enter__()
closure_loss = scaled_loss.__enter__()

# do backward pass
# TODO: not entirely sure, why we need this
@@ -84,10 +92,8 @@ def backward(
closure_loss.backward(*args, **kwargs)

# exit amp context
a, b, c = None, None, None
error = context.__exit__(a, b, c)
error = scaled_loss.__exit__(None, None, None)
if error:
rank_zero_warn(a, b, c)
raise Exception("apex unscale error")

# once backward has been applied, release graph
@@ -97,17 +103,17 @@
def configure_apex(
self,
amp: Type,
model: LightningModule,
model: Module,
optimizers: List[Optimizer],
amp_level: str,
) -> Tuple[LightningModule, List[Optimizer]]:
) -> Tuple[Module, List[Optimizer]]:
r"""
Override to init AMP your own way.
Must return a model and list of optimizers.

Args:
amp: pointer to amp library object.
model: pointer to current :class:`LightningModule`.
model: pointer to current :class:`torch.nn.Module`.
optimizers: list of optimizers passed in :meth:`configure_optimizers`.
amp_level: AMP mode chosen ('O1', 'O2', etc...)

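For context on the apex_amp.py change: annotating the result of `amp.scale_loss` as `ContextManager[Tensor]` is what lets the explicit `__enter__`/`__exit__` calls type-check, and passing `(None, None, None)` to `__exit__` signals a clean exit, the same as falling out of a `with` block without an exception. A minimal sketch of the pattern with a generic context manager (apex is not assumed to be installed, and the scaling factor is illustrative):

```python
from contextlib import contextmanager
from typing import ContextManager, Iterator

@contextmanager
def scale_loss(loss: float) -> Iterator[float]:
    # Stand-in for apex.amp.scale_loss: yields a scaled loss; the real
    # implementation unscales gradients when the context exits.
    yield loss * 128.0

scaled_loss: ContextManager[float] = scale_loss(1.0)
loss = scaled_loss.__enter__()          # same as `with scale_loss(1.0) as loss:`
try:
    pass                                # the backward pass would run here
finally:
    error = scaled_loss.__exit__(None, None, None)
    if error:
        raise Exception("unscale error")
```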
7 changes: 3 additions & 4 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Callable, Union

import torch
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
@@ -54,13 +54,13 @@ def pre_optimizer_step(
def backward(
self,
model: 'pl.LightningModule',
closure_loss: torch.Tensor,
closure_loss: Tensor,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
) -> Tensor:
if is_overridden('backward', model):
warning_cache.warn(
"Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles"
@@ -76,7 +76,6 @@ def backward(

def clip_gradients(
self,
model: 'pl.LightningModule',
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,