fix mypy typing errors in pytorch_lightning/tuner/lr_finder.py #13513

Merged · 13 commits · Jul 8, 2022
Changes from 3 commits
1 change: 0 additions & 1 deletion pyproject.toml
@@ -89,7 +89,6 @@ module = [
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
-    "pytorch_lightning.tuner.lr_finder",
     "pytorch_lightning.tuner.tuning",
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
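Dropping `pytorch_lightning.tuner.lr_finder` from this list removes the module from mypy's per-module exclusions, so the file now has to type-check cleanly; the annotations in the diff below are what make that possible. For orientation, a hedged sketch of the surrounding pyproject.toml block (the `ignore_errors` key and its value are assumptions based on how such override lists are usually wired; only the `module` entries appear in the diff):

    [[tool.mypy.overrides]]
    module = [
        "pytorch_lightning.trainer.trainer",
        # ...modules still excluded from type checking...
    ]
    ignore_errors = true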
46 changes: 25 additions & 21 deletions src/pytorch_lightning/tuner/lr_finder.py
@@ -16,8 +16,9 @@
 import os
 import uuid
 from functools import wraps
-from typing import Any, Dict, Optional, Sequence
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union

+import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
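Hoisting `import matplotlib.pyplot as plt` to module scope is what allows the new `-> plt.Figure` return annotation on `plot()` further down: the annotation is evaluated when the `def` statement runs, so `plt` must already be bound. A minimal sketch of the same pattern:

    import matplotlib.pyplot as plt

    def make_figure() -> plt.Figure:
        # `plt` must be bound at module level for this annotation to resolve.
        fig, _ = plt.subplots()
        return fig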
@@ -29,7 +30,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.types import LRSchedulerConfig
+from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT

 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
@@ -95,16 +96,16 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
         self.lr_max = lr_max
         self.num_training = num_training

-        self.results = {}
+        self.results = {}  # type: Dict[str, Any]
         self._total_batch_idx = 0  # for debug purpose

-    def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
+    def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> Callable[["pl.Trainer"], None]:
         """Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified
         optimizer together with a new scheduler that takes care of the learning rate search."""
         setup_optimizers = trainer.strategy.setup_optimizers

         @wraps(setup_optimizers)
-        def func(trainer):
+        def func(trainer: "pl.Trainer") -> None:
             # Decide the structure of the output from _init_optimizers_and_lr_schedulers
             optimizers, _, _ = _init_optimizers_and_lr_schedulers(trainer.lightning_module)
@@ -122,23 +123,22 @@ def func(trainer):
                 param_group["initial_lr"] = new_lr

             args = (optimizer, self.lr_max, self.num_training)
-            scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
+            scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)  # type: _LRScheduler

             trainer.strategy.optimizers = [optimizer]
-            trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
+            trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]  # type: ignore
             trainer.strategy.optimizer_frequencies = []
             _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)

         return func

-    def plot(self, suggest: bool = False, show: bool = False):
+    def plot(self, suggest: bool = False, show: bool = False) -> plt.Figure:
         """Plot results from lr_find run
         Args:
             suggest: if True, will mark suggested lr to use with a red point

             show: if True, will show figure
         """
-        import matplotlib.pyplot as plt

         lrs = self.results["lr"]
         losses = self.results["loss"]
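For orientation, the public flow that exercises these newly annotated methods looks roughly like the following usage sketch (it assumes a `LightningModule` instance named `model`; `plot` and the `suggestion` method annotated just below are the documented entry points):

    import pytorch_lightning as pl

    trainer = pl.Trainer()
    lr_finder = trainer.tuner.lr_find(model)  # drives the search loop below
    fig = lr_finder.plot(suggest=True)        # plt.Figure, per the new annotation
    new_lr = lr_finder.suggestion()           # Optional[float]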
@@ -162,7 +162,7 @@ def plot(self, suggest: bool = False, show: bool = False):

         return fig

-    def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
+    def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
         """This will propose a suggestion for choice of initial learning rate as the point with the steepest
         negative gradient.

@@ -196,7 +196,7 @@ def lr_find(
     """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping learning rate finder since fast_dev_run is enabled.")
-        return
+        return  # type: ignore

     # Determine lr attr
     if update_attr:
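The bare `return` needs an ignore presumably because the declared return type of `lr_find` is the non-Optional `_LRFinder` at this commit. A hypothetical annotation-level alternative that would satisfy mypy without the comment, sketched with stand-in names:

    from typing import Optional

    class _LRFinderStub:  # stand-in for the real _LRFinder
        pass

    def lr_find_sketch(fast_dev_run: bool) -> Optional[_LRFinderStub]:
        # With an Optional return type, the early exit can return None
        # without a `# type: ignore`.
        if fast_dev_run:
            return None
        return _LRFinderStub()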
@@ -218,7 +218,7 @@
         trainer.progress_bar_callback.disable()

     # Configure optimizer and scheduler
-    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)
+    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)  # type: ignore

     # Fit, lr & loss logged in callback
     trainer.tuner._run(model)
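The ignore on this assignment is needed because mypy rejects assigning to a method attribute ("Cannot assign to a method"), even when the replacement callable is compatible. A self-contained sketch of the same monkey-patch pattern, with illustrative names rather than the real Lightning API:

    from functools import wraps
    from typing import Callable

    class Strategy:
        def setup_optimizers(self) -> None:
            print("original setup")

    def exchange(strategy: Strategy) -> Callable[[], None]:
        original = strategy.setup_optimizers

        @wraps(original)
        def func() -> None:
            original()
            print("lr-finder scheduler installed")

        return func

    s = Strategy()
    s.setup_optimizers = exchange(s)  # type: ignore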
@@ -304,24 +304,28 @@ def __init__(
         self.num_training = num_training
         self.early_stop_threshold = early_stop_threshold
         self.beta = beta
-        self.losses = []
-        self.lrs = []
+        self.losses: List[float] = []
+        self.lrs: List[float] = []
         self.avg_loss = 0.0
         self.best_loss = 0.0
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.progress_bar = None

-    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+    def on_train_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
+    ) -> None:
         """Called before each training batch, logs the lr that will be used."""
         if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return

         if self.progress_bar_refresh_rate and self.progress_bar is None:
             self.progress_bar = tqdm(desc="Finding best initial lr", total=self.num_training)

-        self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0])
+        self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0])  # type: ignore

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+    def on_train_batch_end(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
+    ) -> None:
         """Called when the training batch ends, logs the calculated loss."""
         if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return
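These hook overrides now spell out the base `pl.Callback` signatures in full, which matters because mypy checks overridden methods against the parent class; `STEP_OUTPUT` is Lightning's alias for what a `training_step` may return (roughly a loss tensor or a metrics dict). A minimal typed callback in the same shape:

    from typing import Any

    import pytorch_lightning as pl
    from pytorch_lightning.utilities.types import STEP_OUTPUT

    class Probe(pl.Callback):
        def on_train_batch_end(
            self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
        ) -> None:
            # A mismatched parameter type here would trigger a mypy [override] error.
            print(f"batch {batch_idx} done")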
@@ -372,7 +376,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)

-    def get_lr(self):
+    def get_lr(self) -> List[float]:  # type: ignore[override]
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter

@@ -384,7 +388,7 @@ def get_lr(self):
         return val

     @property
-    def lr(self):
+    def lr(self) -> Union[float, List[float]]:
         return self._lr


@@ -410,7 +414,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)

-    def get_lr(self):
+    def get_lr(self) -> List[float]:  # type: ignore[override]
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter

@@ -422,5 +426,5 @@ def get_lr(self):
         return val

     @property
-    def lr(self):
+    def lr(self) -> Union[float, List[float]]:
         return self._lr
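The `get_lr` bodies are collapsed in this view; for reference, the two schedulers interpolate from each optimizer group's `base_lr` to `end_lr` over `num_iter` steps, linearly and exponentially respectively. The formulas below are the standard interpolations and an assumption about the elided code, not a quote of it:

    base_lr, end_lr, num_iter = 1e-5, 1.0, 100

    def linear_lr(last_epoch: int) -> float:
        r = (last_epoch + 1) / num_iter
        return base_lr + r * (end_lr - base_lr)

    def exponential_lr(last_epoch: int) -> float:
        r = (last_epoch + 1) / num_iter
        return base_lr * (end_lr / base_lr) ** r

    # Both schedules reach exactly end_lr at the last step:
    print(linear_lr(num_iter - 1), exponential_lr(num_iter - 1))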