
Typing for Lightning Package #7035

Closed · wants to merge 34 commits
Changes from 33 commits
34 commits
1fc6c3f
update typing accelerators
justusschock Apr 13, 2021
a1da70e
add typing to early stopping and pruning
justusschock Apr 14, 2021
f90ea1a
type model_checkpoint
justusschock Apr 14, 2021
d3599b3
add typing lr monitor
justusschock Apr 14, 2021
100968f
add typing
justusschock Apr 14, 2021
57f18d1
typing gradient accumulation callback
justusschock Apr 14, 2021
6f4a397
type gpu stats monitor
justusschock Apr 15, 2021
4a811b5
add typing to swa
justusschock Apr 15, 2021
5fbea74
type quantization
justusschock Apr 14, 2021
f6a3fe5
add typing for progbar
justusschock Apr 14, 2021
31dec82
type datamodule
justusschock Apr 14, 2021
1baeeba
type decorators
justusschock Apr 14, 2021
cf21114
type hooks
justusschock Apr 14, 2021
51e3fe8
type memory
justusschock Apr 14, 2021
75e55f3
type lightning optimizer
justusschock Apr 14, 2021
8766f5b
type saving.py
justusschock Apr 14, 2021
2b1fc0d
type results
justusschock Apr 14, 2021
297949b
fix typing of core
justusschock Apr 15, 2021
5f79ef8
type distributed package
justusschock Apr 15, 2021
014736b
type base logger
justusschock Apr 15, 2021
b40a223
type comet logger
justusschock Apr 15, 2021
6ae518d
type csv logger
justusschock Apr 15, 2021
b98fb0e
type mlflow logger
justusschock Apr 15, 2021
f5170d0
type neptune logger
justusschock Apr 15, 2021
558ca08
add typing for tensorboard logger
justusschock Apr 15, 2021
967650c
type test tube logger
justusschock Apr 15, 2021
04fc428
type wandb
justusschock Apr 15, 2021
edc4d8b
remove now unused type ignore
justusschock Apr 15, 2021
750470e
fix annotation
justusschock Apr 15, 2021
549ba31
unused imports
justusschock Apr 15, 2021
b3758ac
fix annotations with conditional imports
justusschock Apr 15, 2021
692c219
fix really stupid bug
justusschock Apr 15, 2021
73519e6
pre-commit
justusschock Apr 15, 2021
67bf8c9
Apply suggestions from code review
Borda Apr 15, 2021
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -243,7 +243,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))


- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917))
- Fixed pickle error checker to now check for `pickle.PickleError` to catch all pickle errors ([#6917](https://github.com/PyTorchLightning/pytorch-lightning/pull/6917))


- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
2 changes: 1 addition & 1 deletion pytorch_lightning/__about__.py
@@ -1,7 +1,7 @@
import time

_this_year = time.strftime("%Y")
__version__ = '1.3.0rc1'
__version__ = "20210414"
__author__ = 'William Falcon et al.'
__author_email__ = '[email protected]'
__license__ = 'Apache-2.0'
34 changes: 18 additions & 16 deletions pytorch_lightning/accelerators/accelerator.py
@@ -12,20 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType

if TYPE_CHECKING:
from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer.trainer import Trainer

_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]

@@ -64,7 +66,7 @@ def __init__(
self.lr_schedulers: Sequence = []
self.optimizer_frequencies: Sequence = []

def connect(self, model: LightningModule) -> None:
def connect(self, model: 'LightningModule') -> None:
"""Transfers ownership of the model to this plugin"""
self.training_type_plugin.connect(model)

@@ -76,7 +78,7 @@ def setup_environment(self) -> None:
"""
self.training_type_plugin.setup_environment()

def setup(self, trainer: 'pl.Trainer', model: LightningModule) -> None:
def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
"""
Setup plugins for the trainer fit and creates optimizers.

@@ -89,23 +91,23 @@ def setup(self, trainer: 'pl.Trainer', model: LightningModule) -> None:
self.setup_optimizers(trainer)
self.setup_precision_plugin(self.precision_plugin)

def start_training(self, trainer: 'pl.Trainer') -> None:
def start_training(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_training(trainer)

def start_evaluating(self, trainer: 'pl.Trainer') -> None:
def start_evaluating(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_evaluating(trainer)

def start_predicting(self, trainer: 'pl.Trainer') -> None:
def start_predicting(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_predicting(trainer)

def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
def pre_dispatch(self, trainer: 'Trainer') -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.training_type_plugin.pre_dispatch()
if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
self.setup_optimizers(trainer)
self.precision_plugin.pre_dispatch()

def post_dispatch(self, trainer: 'pl.Trainer') -> None:
def post_dispatch(self, trainer: 'Trainer') -> None:
"""Hook to do something after the training/evaluation/prediction starts."""
self.training_type_plugin.post_dispatch()
self.precision_plugin.post_dispatch()
@@ -123,7 +125,7 @@ def model(self, new_model: torch.nn.Module) -> None:
self.training_type_plugin.model = new_model

@property
def lightning_module(self) -> LightningModule:
def lightning_module(self) -> 'LightningModule':
"""Returns the pure LightningModule.
To get the potentially wrapped model use :attr:`Accelerator.model`

@@ -341,7 +343,7 @@ def on_train_end(self) -> None:
"""Hook to do something at the end of the training"""
pass

def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
def setup_optimizers(self, trainer: 'Trainer') -> None:
"""creates optimizers and schedulers

Args:
@@ -357,7 +359,7 @@ def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
self.lr_schedulers = lr_schedulers
self.optimizer_frequencies = optimizer_frequencies

def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'LightningModule') -> None:
"""Attaches the training type plugin to the accelerator."""
plugin.setup(model)

@@ -378,7 +380,7 @@ def to_device(self, batch: Any) -> Any:
return batch[0] if is_dict else batch

@property
def amp_backend(self) -> Optional[LightningEnum]:
def amp_backend(self) -> Optional['AMPType']:
if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
return AMPType.APEX
elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
@@ -464,7 +466,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:
yield

# todo: remove in v1.5
def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'LightningModule') -> None:
"""
Attaches the training type plugin to the accelerator.
Also transfers ownership of the model to this plugin
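The change repeated across these hunks — importing Trainer and LightningModule only under typing.TYPE_CHECKING and writing the annotations as quoted strings — is the usual way to add type hints across a circular dependency: the names stay visible to static checkers such as mypy, but the import never runs at runtime, so the accelerator module no longer pulls in the whole pytorch_lightning package when it is imported. A minimal sketch of the pattern, using hypothetical modules engine.py and model.py rather than the real package layout:

# engine.py — hypothetical module, only to illustrate the TYPE_CHECKING pattern above.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers; skipped at runtime, so importing
    # engine.py can never trigger a circular import of model.py.
    from model import Model


class Engine:
    def connect(self, model: 'Model') -> None:
        # Quoted ("forward reference") annotation: the checker resolves it lazily,
        # so the name does not have to be importable when this module loads.
        self.model = model

The quotes matter because 'Model' is not bound at runtime; without them the annotation would raise NameError as soon as the def statement is evaluated.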
9 changes: 7 additions & 2 deletions pytorch_lightning/accelerators/cpu.py
@@ -11,16 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytorch_lightning as pl
from typing import TYPE_CHECKING

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if TYPE_CHECKING:
from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer.trainer import Trainer


class CPUAccelerator(Accelerator):
""" Accelerator for CPU devices. """

def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
"""
Raises:
MisconfigurationException:
9 changes: 6 additions & 3 deletions pytorch_lightning/accelerators/gpu.py
@@ -13,22 +13,25 @@
# limitations under the License.
import logging
import os
from typing import Any
from typing import Any, TYPE_CHECKING

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_log = logging.getLogger(__name__)

if TYPE_CHECKING:
from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer.trainer import Trainer


class GPUAccelerator(Accelerator):
""" Accelerator for GPU devices. """

def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
"""
Raises:
MisconfigurationException:
34 changes: 21 additions & 13 deletions pytorch_lightning/accelerators/tpu.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Union
from typing import Any, Callable, TYPE_CHECKING, Union

from torch.optim import Optimizer

@@ -28,13 +28,17 @@

xla_clip_grad_norm_ = clip_grad_norm_

import pytorch_lightning as pl
if TYPE_CHECKING:
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer.trainer import Trainer


class TPUAccelerator(Accelerator):
""" Accelerator for TPU devices. """

def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
"""
Raises:
MisconfigurationException:
@@ -51,19 +55,11 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
return super().setup(trainer, model)

def run_optimizer_step(
self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
self, optimizer: 'Optimizer', optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
) -> None:
xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[float, int],
norm_type: float = 2.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM
) -> None:
assert gradient_clip_algorithm is GradClipAlgorithmType.NORM, \
"Only NORM gradient clipping is supported on TPU for now"
def _clip_gradients_norm(self, clip_val: Union[float, int], norm_type: float = 2.0) -> None:

model = self.lightning_module
parameters = model.parameters()
@@ -75,3 +71,15 @@ def clip_gradients(
max_norm = grad_clip_val

xla_clip_grad_norm_(parameters, max_norm, norm_type)

def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:

if gradient_clip_algorithm == GradClipAlgorithmType.NORM:
return self._clip_gradients_norm(clip_val=clip_val)

raise NotImplementedError("Only NORM gradient clipping is supported on TPU for now")
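The NORM branch rescales gradients by their global norm; as a point of reference, this is what norm-based clipping looks like in a plain CPU/GPU training step using torch.nn.utils.clip_grad_norm_. The model and data below are hypothetical, and this is not the XLA-specific path — the TPU code above delegates to an XLA-compatible clip_grad_norm_ imported elsewhere and aliased as xla_clip_grad_norm_.

import torch
from torch.nn.utils import clip_grad_norm_

# Hypothetical model and a single step, showing where norm clipping sits
# between backward() and optimizer.step().
model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(4, 8), torch.randn(4, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()

# Rescale all gradients in place so their global 2-norm is at most max_norm.
clip_grad_norm_(model.parameters(), max_norm=0.5, norm_type=2.0)
optimizer.step()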