diff --git a/pyproject.toml b/pyproject.toml index 51781d4953935..1e2c59d28a3db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,6 @@ module = [ "pytorch_lightning.trainer.supporters", "pytorch_lightning.trainer.trainer", "pytorch_lightning.tuner.batch_size_scaling", - "pytorch_lightning.tuner.lr_finder", "pytorch_lightning.tuner.tuning", "pytorch_lightning.utilities.auto_restart", "pytorch_lightning.utilities.data", diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py index ad15707d079f1..9e584708f369e 100644 --- a/src/pytorch_lightning/tuner/lr_finder.py +++ b/src/pytorch_lightning/tuner/lr_finder.py @@ -16,7 +16,7 @@ import os import uuid from functools import wraps -from typing import Any, Dict, Optional, Sequence +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TYPE_CHECKING, Union import numpy as np import torch @@ -27,9 +27,10 @@ from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx from pytorch_lightning.loggers.logger import DummyLogger from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _RequirementAvailable from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr from pytorch_lightning.utilities.rank_zero import rank_zero_warn -from pytorch_lightning.utilities.types import LRSchedulerConfig +from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed @@ -38,6 +39,10 @@ else: from tqdm import tqdm +_MATPLOTLIB_AVAILABLE: bool = _RequirementAvailable("matplotlib") # type: ignore[assignment] +if _MATPLOTLIB_AVAILABLE and TYPE_CHECKING: + import matplotlib.pyplot as plt + log = logging.getLogger(__name__) @@ -95,16 +100,16 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int): self.lr_max = lr_max self.num_training = num_training - self.results = {} + self.results: Dict[str, Any] = {} self._total_batch_idx = 0 # for debug purpose - def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule"): + def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> Callable[["pl.Trainer"], None]: """Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified optimizer together with a new scheduler that takes care of the learning rate search.""" setup_optimizers = trainer.strategy.setup_optimizers @wraps(setup_optimizers) - def func(trainer): + def func(trainer: "pl.Trainer") -> None: # Decide the structure of the output from _init_optimizers_and_lr_schedulers optimizers, _, _ = _init_optimizers_and_lr_schedulers(trainer.lightning_module) @@ -123,6 +128,7 @@ def func(trainer): args = (optimizer, self.lr_max, self.num_training) scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args) + scheduler = cast(pl.utilities.types._LRScheduler, scheduler) trainer.strategy.optimizers = [optimizer] trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)] @@ -131,13 +137,18 @@ def func(trainer): return func - def plot(self, suggest: bool = False, show: bool = False): + def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figure"]: """Plot results from lr_find run Args: suggest: if True, will mark suggested lr to use with a red point show: if True, will show figure """ + if not _MATPLOTLIB_AVAILABLE: + raise MisconfigurationException( + "To use the `plot` method, you must have Matplotlib installed." + " Install it by running `pip install -U matplotlib`." + ) import matplotlib.pyplot as plt lrs = self.results["lr"] @@ -162,7 +173,7 @@ def plot(self, suggest: bool = False, show: bool = False): return fig - def suggestion(self, skip_begin: int = 10, skip_end: int = 1): + def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]: """This will propose a suggestion for choice of initial learning rate as the point with the steepest negative gradient. @@ -196,7 +207,7 @@ def lr_find( """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`""" if trainer.fast_dev_run: rank_zero_warn("Skipping learning rate finder since fast_dev_run is enabled.") - return + return None # Determine lr attr if update_attr: @@ -218,7 +229,7 @@ def lr_find( trainer.progress_bar_callback.disable() # Configure optimizer and scheduler - trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model) + trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model) # type: ignore[assignment] # Fit, lr & loss logged in callback trainer.tuner._run(model) @@ -304,14 +315,16 @@ def __init__( self.num_training = num_training self.early_stop_threshold = early_stop_threshold self.beta = beta - self.losses = [] - self.lrs = [] + self.losses: List[float] = [] + self.lrs: List[float] = [] self.avg_loss = 0.0 self.best_loss = 0.0 self.progress_bar_refresh_rate = progress_bar_refresh_rate self.progress_bar = None - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + def on_train_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int + ) -> None: """Called before each training batch, logs the lr that will be used.""" if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: return @@ -319,9 +332,11 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): if self.progress_bar_refresh_rate and self.progress_bar is None: self.progress_bar = tqdm(desc="Finding best initial lr", total=self.num_training) - self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0]) + self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0]) # type: ignore[union-attr] - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + def on_train_batch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: """Called when the training batch ends, logs the calculated loss.""" if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: return @@ -372,7 +387,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in self.num_iter = num_iter super().__init__(optimizer, last_epoch) - def get_lr(self): + def get_lr(self) -> List[float]: # type: ignore[override] curr_iter = self.last_epoch + 1 r = curr_iter / self.num_iter @@ -384,7 +399,7 @@ def get_lr(self): return val @property - def lr(self): + def lr(self) -> Union[float, List[float]]: return self._lr @@ -410,7 +425,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in self.num_iter = num_iter super().__init__(optimizer, last_epoch) - def get_lr(self): + def get_lr(self) -> List[float]: # type: ignore[override] curr_iter = self.last_epoch + 1 r = curr_iter / self.num_iter @@ -422,5 +437,5 @@ def get_lr(self): return val @property - def lr(self): + def lr(self) -> Union[float, List[float]]: return self._lr