Typing for accelerators and plugins #7022

Merged
merged 31 commits on Apr 15, 2021

Changes from 23 commits

Commits (31)
9d7d440
Add typings for evaluation_loop.py
ethanwharris Apr 14, 2021
f3d5b54
Fix PEP
ethanwharris Apr 14, 2021
7b9f13b
Fix some tests
ethanwharris Apr 14, 2021
8d84fb4
Run pre-commit
ethanwharris Apr 14, 2021
9b29428
Apply suggestions from code review
Borda Apr 14, 2021
4f47b8b
Merge branch 'master' into docs/evaluation_loop_typing
carmocca Apr 14, 2021
03bde1a
Update setup.cfg
ethanwharris Apr 14, 2021
ad960bc
Fix some mypy issues
ethanwharris Apr 14, 2021
cfdbad7
Updates
ethanwharris Apr 14, 2021
e21ad93
Fix
ethanwharris Apr 14, 2021
42b60f9
Fix typing for accelerators and plugins
carmocca Apr 14, 2021
72d28a2
Merge branch 'master' into typing-accelerators-plugins
carmocca Apr 14, 2021
fdf0f0e
pre-commit
carmocca Apr 14, 2021
177f604
Fix mypy
carmocca Apr 14, 2021
75e499e
Fix typing
carmocca Apr 14, 2021
f949e33
Fix typing
carmocca Apr 14, 2021
04ff62e
Fix typing
carmocca Apr 14, 2021
49fe989
Duplicate import
carmocca Apr 14, 2021
ab32167
Fix typing
carmocca Apr 14, 2021
74d4376
Fix typing
carmocca Apr 14, 2021
429c61e
Merge branch 'master' into typing-accelerators-plugins
carmocca Apr 15, 2021
f40252c
Bad merge
carmocca Apr 15, 2021
0216d2a
Undo some changes
carmocca Apr 15, 2021
b539786
Undo forward references
carmocca Apr 15, 2021
4c93cf4
Address comment
carmocca Apr 15, 2021
a1b1247
Forward reference OSS
carmocca Apr 15, 2021
c677107
Forward reference GradScaler
carmocca Apr 15, 2021
e80cf44
Minor changes
carmocca Apr 15, 2021
6c14757
Update pytorch_lightning/accelerators/accelerator.py
carmocca Apr 15, 2021
e005c44
flake8
carmocca Apr 15, 2021
8dd9460
Update pytorch_lightning/plugins/precision/apex_amp.py
carmocca Apr 15, 2021
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -243,7 +243,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))


- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917))


- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
129 changes: 65 additions & 64 deletions pytorch_lightning/accelerators/accelerator.py

Large diffs are not rendered by default.

32 changes: 14 additions & 18 deletions pytorch_lightning/accelerators/tpu.py
@@ -11,9 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Union

from torch.optim import Optimizer
from typing import Any, Callable, TYPE_CHECKING, Union

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
@@ -22,19 +20,20 @@
from pytorch_lightning.utilities import _XLA_AVAILABLE, GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if TYPE_CHECKING:
from torch.optim import Optimizer

from pytorch_lightning import LightningModule, Trainer

if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
from torch_xla._patched_functions import clip_grad_norm_

xla_clip_grad_norm_ = clip_grad_norm_

import pytorch_lightning as pl


class TPUAccelerator(Accelerator):
""" Accelerator for TPU devices. """

def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
"""
Raises:
MisconfigurationException:
@@ -51,27 +50,24 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
return super().setup(trainer, model)

def run_optimizer_step(
self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
self, optimizer: 'Optimizer', optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
) -> None:
xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

def clip_gradients(
self,
optimizer: Optimizer,
optimizer: 'Optimizer',
clip_val: Union[float, int],
norm_type: float = 2.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
assert gradient_clip_algorithm is GradClipAlgorithmType.NORM, \
assert gradient_clip_algorithm == GradClipAlgorithmType.NORM, \
"Only NORM gradient clipping is supported on TPU for now"

model = self.lightning_module
parameters = model.parameters()

grad_clip_val = float(clip_val)
if grad_clip_val <= 0:
return

max_norm = grad_clip_val
parameters = self.model.parameters()
Contributor Author (review comment): Is this equivalent to what we had before?

norm_type = 2.0

xla_clip_grad_norm_(parameters, max_norm, norm_type)
clip_grad_norm_(parameters, grad_clip_val, norm_type)
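
The tpu.py changes above move typing-only imports behind `TYPE_CHECKING` and reference them as quoted forward annotations. Below is a minimal, self-contained sketch of that pattern; the `run_optimizer_step` body is a simplified stand-in (the real accelerator delegates to `xm.optimizer_step`), not the Lightning source.

```python
from typing import Any, Callable, TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers such as mypy, never at runtime,
    # so this import cannot create runtime import cycles or hard dependencies.
    from torch.optim import Optimizer


def run_optimizer_step(optimizer: 'Optimizer', lambda_closure: Callable, **kwargs: Any) -> None:
    # The quoted 'Optimizer' annotation is resolved lazily, so torch.optim is
    # never imported when this module is executed.
    optimizer.step(closure=lambda_closure, **kwargs)
```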
48 changes: 29 additions & 19 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Generator, List, Sequence, Tuple, Type
from typing import Any, Callable, ContextManager, Iterator, List, Sequence, Tuple, Type, TYPE_CHECKING

import torch
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
@@ -23,6 +22,12 @@
if _APEX_AVAILABLE:
from apex import amp

if TYPE_CHECKING:
from torch import Tensor
from torch.nn import Module, Parameter
from torch.optim import Optimizer
PARAMETERS = Iterator[Parameter]


class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
"""Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""
@@ -32,11 +37,15 @@ def __init__(self, amp_level: str = "O2") -> None:
self.backend = AMPType.APEX
self.amp_level = amp_level

def master_params(self, optimizer: Optimizer) -> Generator[torch.Tensor, None, None]:
def master_params(self, optimizer: 'Optimizer') -> 'PARAMETERS':
return amp.master_params(optimizer)

def connect(self, model: torch.nn.Module, optimizers: Sequence[Optimizer],
lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence[Optimizer], Sequence[Any]]:
def connect(
self,
model: 'Module',
optimizers: Sequence['Optimizer'],
lr_schedulers: Sequence[Any],
) -> Tuple['Module', Sequence['Optimizer'], Sequence[Any]]:
"""Connects the precision plugin to the training process,
configures apex and reinits the schedulers
"""
@@ -49,28 +58,29 @@ def connect(self, model: torch.nn.Module, optimizers: Sequence[Optimizer],
def backward(
self,
model: LightningModule,
closure_loss: torch.Tensor,
optimizer: Optimizer,
closure_loss: 'Tensor',
optimizer: 'Optimizer',
opt_idx: int,
should_accumulate: bool,
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
) -> 'Tensor':
"""performs the actual backpropagation

Args:
model: the model to be optimized
closure_loss: the loss value obtained from the closure
optimizer: the optimizer to perform the step later on
opt_idx: the optimizer's index
opt_idx: the optimizer index
should_accumulate: whether to accumulate gradients or not

"""
closure_loss = amp.scale_loss(closure_loss, model.trainer.optimizers if optimizer is None else optimizer)
scaled_loss: ContextManager['Tensor'] = amp.scale_loss(
closure_loss, model.trainer.optimizers if optimizer is None else optimizer
)

# enter apex context
context = closure_loss
closure_loss = closure_loss.__enter__()
closure_loss = scaled_loss.__enter__()

# do backward pass
# TODO: not entirely sure, why we need this
@@ -85,7 +95,7 @@ def backward(

# exit amp context
a, b, c = None, None, None
error = context.__exit__(a, b, c)
error = scaled_loss.__exit__(a, b, c)
if error:
rank_zero_warn(a, b, c)
raise Exception("apex unscale error")
@@ -97,17 +107,17 @@ def backward(
def configure_apex(
self,
amp: Type,
model: LightningModule,
optimizers: List[Optimizer],
model: 'Module',
optimizers: List['Optimizer'],
amp_level: str,
) -> Tuple[LightningModule, List[Optimizer]]:
) -> Tuple['Module', List['Optimizer']]:
r"""
Override to init AMP your own way.
Must return a model and list of optimizers.

Args:
amp: pointer to amp library object.
model: pointer to current :class:`LightningModule`.
model: pointer to current :class:`torch.nn.Module`.
optimizers: list of optimizers passed in :meth:`configure_optimizers`.
amp_level: AMP mode chosen ('O1', 'O2', etc...)

@@ -129,7 +139,7 @@ def configure_apex(self, amp, model, optimizers, amp_level):
return model, optimizers

@staticmethod
def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Sequence[Any]) -> None:
def reinit_scheduler_properties(optimizers: Sequence['Optimizer'], schedulers: Sequence[Any]) -> None:
"""Reinitializes schedulers with correct properties"""
# Reinitialize optimizer.step properties added by schedulers
for scheduler in schedulers:
@@ -153,7 +163,7 @@ def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Seq
def pre_optimizer_step(
self,
pl_module: LightningModule,
optimizer: Optimizer,
optimizer: 'Optimizer',
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
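
The apex backward above annotates `amp.scale_loss(...)` as `ContextManager['Tensor']` and drives it manually through `__enter__`/`__exit__`. Here is a hedged, standalone sketch of that manual context-manager pattern; a toy `scale_loss` stands in for apex (which may not be installed), so none of the names below belong to the actual plugin.

```python
from contextlib import contextmanager
from typing import ContextManager, Iterator


@contextmanager
def scale_loss(loss: float, scale: float = 128.0) -> Iterator[float]:
    # Stand-in for apex's amp.scale_loss: yield the scaled value to the caller.
    yield loss * scale


def backward_like(loss: float) -> float:
    scaled: ContextManager[float] = scale_loss(loss)
    scaled_loss = scaled.__enter__()           # enter the context by hand
    # ... a real plugin would run the backward pass on scaled_loss here ...
    error = scaled.__exit__(None, None, None)  # exit by hand; falsy when no exception
    if error:
        raise Exception("scale error")
    return scaled_loss


print(backward_like(2.0))  # 256.0
```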
27 changes: 14 additions & 13 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -11,17 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Union
from typing import Any, Callable, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache

if TYPE_CHECKING:
from torch import Tensor
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule

warning_cache = WarningCache()


@@ -34,8 +36,8 @@ def __init__(self, precision: int) -> None:

def pre_optimizer_step(
self,
pl_module: 'pl.LightningModule',
optimizer: Optimizer,
pl_module: 'LightningModule',
optimizer: 'Optimizer',
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
@@ -53,14 +55,14 @@ def pre_optimizer_step(

def backward(
self,
model: 'pl.LightningModule',
closure_loss: torch.Tensor,
optimizer: Optimizer,
model: 'LightningModule',
closure_loss: 'Tensor',
optimizer: 'Optimizer',
opt_idx: int,
should_accumulate: bool,
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
) -> 'Tensor':
if is_overridden('backward', model):
warning_cache.warn(
"Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles"
@@ -76,8 +78,7 @@ def backward(

def clip_gradients(
self,
model: 'pl.LightningModule',
optimizer: Optimizer,
optimizer: 'Optimizer',
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
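
The DeepSpeed plugin above calls `warning_cache.warn(...)` when a user-overridden `backward` hook will be ignored. The sketch below shows a warn-once cache with that shape; it is an assumption about the behaviour, not the actual `pytorch_lightning.utilities.warnings.WarningCache` implementation.

```python
import warnings
from typing import Set


class WarningCache:
    """Emit each distinct warning message at most once per cache instance."""

    def __init__(self) -> None:
        self._seen: Set[str] = set()

    def warn(self, message: str) -> None:
        # Suppress messages that were already emitted by this cache.
        if message not in self._seen:
            self._seen.add(message)
            warnings.warn(message)


warning_cache = WarningCache()
warning_cache.warn("Overridden backward hook will be ignored by DeepSpeed")  # emitted
warning_cache.warn("Overridden backward hook will be ignored by DeepSpeed")  # suppressed
```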