From 78c33ce139981d9f42ebe55ed2681c361c7bccbf Mon Sep 17 00:00:00 2001
From: donlapark
Date: Sun, 3 Jul 2022 01:00:07 +0700
Subject: [PATCH 01/13] fix typing in pytorch_lightning/tuner/lr_finder.py

---
 pyproject.toml                           |  1 -
 src/pytorch_lightning/tuner/lr_finder.py | 42 ++++++++++++------------
 2 files changed, 21 insertions(+), 22 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 51781d4953935..1e2c59d28a3db 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -89,7 +89,6 @@ module = [
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
-    "pytorch_lightning.tuner.lr_finder",
     "pytorch_lightning.tuner.tuning",
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index ad15707d079f1..88bdb1a7b4868 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -16,9 +16,10 @@
 import os
 import uuid
 from functools import wraps
-from typing import Any, Dict, Optional, Sequence
+from typing import Any, List, Dict, Optional, Union, Sequence, Callable
 
 import numpy as np
+import matplotlib.pyplot as plt
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 
@@ -29,7 +30,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.types import LRSchedulerConfig
+from pytorch_lightning.utilities.types import LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT
 
 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
@@ -95,16 +96,16 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
         self.lr_max = lr_max
         self.num_training = num_training
 
-        self.results = {}
+        self.results = {}  # type: Dict[str, Any]
         self._total_batch_idx = 0  # for debug purpose
 
-    def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
+    def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> Callable[["pl.Trainer"], None]:
         """Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified
         optimizer together with a new scheduler that takes care of the learning rate search."""
         setup_optimizers = trainer.strategy.setup_optimizers
 
         @wraps(setup_optimizers)
-        def func(trainer):
+        def func(trainer: "pl.Trainer") -> None:
             # Decide the structure of the output from _init_optimizers_and_lr_schedulers
             optimizers, _, _ = _init_optimizers_and_lr_schedulers(trainer.lightning_module)
 
@@ -122,23 +123,22 @@ def func(trainer):
                 param_group["initial_lr"] = new_lr
 
             args = (optimizer, self.lr_max, self.num_training)
-            scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
+            scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)  # type: _LRScheduler
 
             trainer.strategy.optimizers = [optimizer]
-            trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
+            trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]  # type: ignore
             trainer.strategy.optimizer_frequencies = []
             _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)
 
         return func
 
-    def plot(self, suggest: bool = False, show: bool = False):
+    def plot(self, suggest: bool = False, show: bool = False) -> plt.Figure:
         """Plot results from lr_find run
         Args:
             suggest: if True, will mark suggested lr to use with a red point
             show: if True, will show figure
         """
-        import matplotlib.pyplot as plt
 
         lrs = self.results["lr"]
         losses = self.results["loss"]
 
@@ -162,7 +162,7 @@ def plot(self, suggest: bool = False, show: bool = False):
 
         return fig
 
-    def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
+    def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
         """This will propose a suggestion for choice of initial learning rate as the point with the steepest negative
         gradient.
 
@@ -196,7 +196,7 @@ def lr_find(
     """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping learning rate finder since fast_dev_run is enabled.")
-        return
+        return  # type: ignore
 
     # Determine lr attr
     if update_attr:
 
@@ -218,7 +218,7 @@ def lr_find(
         trainer.progress_bar_callback.disable()
 
     # Configure optimizer and scheduler
-    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)
+    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)  # type: ignore
 
     # Fit, lr & loss logged in callback
     trainer.tuner._run(model)
 
@@ -304,14 +304,14 @@ def __init__(
         self.num_training = num_training
         self.early_stop_threshold = early_stop_threshold
         self.beta = beta
-        self.losses = []
-        self.lrs = []
+        self.losses: List[float] = []
+        self.lrs: List[float] = []
         self.avg_loss = 0.0
         self.best_loss = 0.0
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.progress_bar = None
 
-    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+    def on_train_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int) -> None:
         """Called before each training batch, logs the lr that will be used."""
         if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return
 
@@ -319,9 +319,9 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
         if self.progress_bar_refresh_rate and self.progress_bar is None:
             self.progress_bar = tqdm(desc="Finding best initial lr", total=self.num_training)
 
-        self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0])
+        self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0])  # type: ignore
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
         """Called when the training batch ends, logs the calculated loss."""
         if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return
 
@@ -372,7 +372,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)
 
-    def get_lr(self):
+    def get_lr(self) -> List[float]:  # type: ignore[override]
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
 
@@ -384,7 +384,7 @@ def get_lr(self):
         return val
 
     @property
-    def lr(self):
+    def lr(self) -> Union[float, List[float]]:
         return self._lr
 
@@ -410,7 +410,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)
 
-    def get_lr(self):
+    def get_lr(self) -> List[float]:  # type: ignore[override]
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
 
@@ -422,5 +422,5 @@ def get_lr(self):
         return val
 
     @property
-    def lr(self):
+    def lr(self) -> Union[float, List[float]]:
         return self._lr

From c3601e9a03cfbe64a21181bee039a0d354e344ad Mon Sep 17 00:00:00 2001
From: donlapark
Date: Sun, 3 Jul 2022 01:05:07 +0700
Subject: [PATCH 02/13] fix typing in pytorch_lightning/tuner/lr_finder.py

---
 src/pytorch_lightning/tuner/lr_finder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 88bdb1a7b4868..d47dcf89a3a69 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -30,7 +30,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.types import LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT
+from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT
 
 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed

From 96da1e7fb5a760db6dbd728bc4dd4bafb48aab36 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 2 Jul 2022 19:05:54 +0000
Subject: [PATCH 03/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_lightning/tuner/lr_finder.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index d47dcf89a3a69..6bea86ab7cd99 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -16,10 +16,10 @@
 import os
 import uuid
 from functools import wraps
-from typing import Any, List, Dict, Optional, Union, Sequence, Callable
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 
@@ -311,7 +311,9 @@ def __init__(
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.progress_bar = None
 
-    def on_train_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int) -> None:
+    def on_train_batch_start(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
+    ) -> None:
         """Called before each training batch, logs the lr that will be used."""
         if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return
 
@@ -321,7 +323,9 @@ def on_train_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
 
         self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0])  # type: ignore
 
-    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
+    def on_train_batch_end(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
+    ) -> None:
         """Called when the training batch ends, logs the calculated loss."""
         if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return

From 7645add6771a35f87e817fcc3d41a9a5e86e4e7f Mon Sep 17 00:00:00 2001
From: donlapark <10988155+donlapark@users.noreply.github.com>
Date: Sun, 3 Jul 2022 13:33:45 +0700
Subject: [PATCH 04/13] Add a check before importing matplotlib

---
 src/pytorch_lightning/tuner/lr_finder.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 6bea86ab7cd99..7f740cd61ef8d 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -28,6 +28,7 @@
 from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx
 from pytorch_lightning.loggers.logger import DummyLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _RequirementAvailable
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT
 
@@ -39,6 +40,10 @@ else:
     from tqdm import tqdm
 
+_MATPLOTLIB_AVAILABLE: bool = _RequirementAvailable("matplotlib")
+if _MATPLOTLIB_AVAILABLE:
+    import matplotlib.pyplot as plt
+
 log = logging.getLogger(__name__)
 
@@ -132,14 +137,19 @@ def func(trainer: "pl.Trainer") -> None:
 
         return func
 
-    def plot(self, suggest: bool = False, show: bool = False) -> plt.Figure:
+    def plot(self, suggest: bool = False, show: bool = False) -> Optional[plt.Figure]:
         """Plot results from lr_find run
         Args:
             suggest: if True, will mark suggested lr to use with a red point
             show: if True, will show figure
         """
-
+        if not _MATPLOTLIB_AVAILABLE:
+            raise MisconfigurationException(
+                "To use the `plot` method, you must have Matplotlib installed."
+                " Install it by running `pip install -U matplotlib`."
+            )
+
         lrs = self.results["lr"]
         losses = self.results["loss"]

From 14e9c7aac3e3144ab686779b6d59e9018416fe8a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 3 Jul 2022 06:35:25 +0000
Subject: [PATCH 05/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_lightning/tuner/lr_finder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 7f740cd61ef8d..c4dd0b3b79798 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -149,7 +149,7 @@ def plot(self, suggest: bool = False, show: bool = False) -> Optional[plt.Figure
                 "To use the `plot` method, you must have Matplotlib installed."
                 " Install it by running `pip install -U matplotlib`."
             )
-
+
         lrs = self.results["lr"]
         losses = self.results["loss"]

From 18d56106737781583633cfb30afd96d693ef18da Mon Sep 17 00:00:00 2001
From: donlapark <10988155+donlapark@users.noreply.github.com>
Date: Sun, 3 Jul 2022 13:47:03 +0700
Subject: [PATCH 06/13] Update src/pytorch_lightning/tuner/lr_finder.py

Co-authored-by: Akihiro Nitta
---
 src/pytorch_lightning/tuner/lr_finder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index c4dd0b3b79798..e17c0988744e7 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -101,7 +101,7 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
         self.lr_max = lr_max
         self.num_training = num_training
 
-        self.results = {}  # type: Dict[str, Any]
+        self.results: Dict[str, Any] = {}
         self._total_batch_idx = 0  # for debug purpose
 
     def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> Callable[["pl.Trainer"], None]:

From 84bad3ada421806f50b2cc84cf9c3c5c286ed40f Mon Sep 17 00:00:00 2001
From: donlapark <10988155+donlapark@users.noreply.github.com>
Date: Sun, 3 Jul 2022 13:49:19 +0700
Subject: [PATCH 07/13] Update src/pytorch_lightning/tuner/lr_finder.py

Co-authored-by: Akihiro Nitta
---
 src/pytorch_lightning/tuner/lr_finder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index e17c0988744e7..c0f1da17ccc97 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -206,7 +206,7 @@ def lr_find(
     """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping learning rate finder since fast_dev_run is enabled.")
-        return  # type: ignore
+        return None
 
     # Determine lr attr
     if update_attr:

From 4d6068cd1300e4a991d36352d8111f175d9568ac Mon Sep 17 00:00:00 2001
From: donlapark <10988155+donlapark@users.noreply.github.com>
Date: Sun, 3 Jul 2022 14:02:17 +0700
Subject: [PATCH 08/13] Add a check before importing matplotlib - fixed

---
 src/pytorch_lightning/tuner/lr_finder.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index c0f1da17ccc97..782b2252f0ee6 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -18,7 +18,6 @@
 from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
-import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from torch.optim.lr_scheduler import _LRScheduler

From 4374a07444146246fa07a5e786a05d7dc1386b60 Mon Sep 17 00:00:00 2001
From: donlapark
Date: Sun, 3 Jul 2022 23:49:29 +0700
Subject: [PATCH 09/13] cast _LinearLR and _ExponentialLR as pl's _LRScheduler type

---
 src/pytorch_lightning/tuner/lr_finder.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 782b2252f0ee6..67976f3b51b67 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -16,7 +16,7 @@
 import os
 import uuid
 from functools import wraps
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Union
 
 import numpy as np
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 
@@ -39,7 +39,7 @@
 else:
     from tqdm import tqdm
 
-_MATPLOTLIB_AVAILABLE: bool = _RequirementAvailable("matplotlib")
+_MATPLOTLIB_AVAILABLE: bool = _RequirementAvailable("matplotlib")  # type: ignore[assignment]
 if _MATPLOTLIB_AVAILABLE:
     import matplotlib.pyplot as plt
 
@@ -127,16 +127,17 @@ def func(trainer: "pl.Trainer") -> None:
                 param_group["initial_lr"] = new_lr
 
             args = (optimizer, self.lr_max, self.num_training)
-            scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)  # type: _LRScheduler
+            scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
+            scheduler = cast(pl.utilities.types._LRScheduler, scheduler)
 
             trainer.strategy.optimizers = [optimizer]
-            trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]  # type: ignore
+            trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
             trainer.strategy.optimizer_frequencies = []
             _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)
 
         return func
 
-    def plot(self, suggest: bool = False, show: bool = False) -> Optional[plt.Figure]:
+    def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figure"]:
         """Plot results from lr_find run
         Args:
             suggest: if True, will mark suggested lr to use with a red point
             show: if True, will show figure
         """
@@ -227,7 +228,7 @@ def lr_find(
         trainer.progress_bar_callback.disable()
 
     # Configure optimizer and scheduler
-    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)  # type: ignore
+    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)  # type: ignore[assignment]
 
     # Fit, lr & loss logged in callback
     trainer.tuner._run(model)
 
@@ -330,7 +331,7 @@ def on_train_batch_start(
         if self.progress_bar_refresh_rate and self.progress_bar is None:
             self.progress_bar = tqdm(desc="Finding best initial lr", total=self.num_training)
 
-        self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0])  # type: ignore
+        self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0])  # type: ignore[union-attr]
 
     def on_train_batch_end(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int

From 8c65e1f1c1035bb0b356ff52255664a6b8e7e583 Mon Sep 17 00:00:00 2001
From: donlapark <10988155+donlapark@users.noreply.github.com>
Date: Mon, 4 Jul 2022 11:34:02 +0000
Subject: [PATCH 10/13] Add `TYPE_CHECKING`

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
---
 src/pytorch_lightning/tuner/lr_finder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 67976f3b51b67..81490bacaaa89 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -16,7 +16,7 @@
 import os
 import uuid
 from functools import wraps
-from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Union, TYPE_CHECKING
 
 import numpy as np
 import torch

From f42a4391e530319396b2a11ed43c0ca47b4e81d3 Mon Sep 17 00:00:00 2001
From: donlapark <10988155+donlapark@users.noreply.github.com>
Date: Mon, 4 Jul 2022 18:34:50 +0700
Subject: [PATCH 11/13] Add `TYPE_CHECKING`

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
---
 src/pytorch_lightning/tuner/lr_finder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 81490bacaaa89..785d3d73a8980 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -40,7 +40,7 @@
     from tqdm import tqdm
 
 _MATPLOTLIB_AVAILABLE: bool = _RequirementAvailable("matplotlib")  # type: ignore[assignment]
-if _MATPLOTLIB_AVAILABLE:
+if _MATPLOTLIB_AVAILABLE and TYPE_CHECKING:
     import matplotlib.pyplot as plt
 
 log = logging.getLogger(__name__)

From c6a28ffb3fd2bc54d405e18ff4b503e70a957a69 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 4 Jul 2022 11:35:34 +0000
Subject: [PATCH 12/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_lightning/tuner/lr_finder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 785d3d73a8980..f375ec6dcf8a4 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -16,7 +16,7 @@
 import os
 import uuid
 from functools import wraps
-from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Union, TYPE_CHECKING
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TYPE_CHECKING, Union
 
 import numpy as np
 import torch

From 6783e4eaa26276d1ea9700456ea3a7b068a63bbf Mon Sep 17 00:00:00 2001
From: donlapark <10988155+donlapark@users.noreply.github.com>
Date: Mon, 4 Jul 2022 18:37:02 +0700
Subject: [PATCH 13/13] Revert matplotlib import

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
---
 src/pytorch_lightning/tuner/lr_finder.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index f375ec6dcf8a4..9e584708f369e 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -149,6 +149,7 @@ def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figur
                 "To use the `plot` method, you must have Matplotlib installed."
                 " Install it by running `pip install -U matplotlib`."
             )
+        import matplotlib.pyplot as plt
 
         lrs = self.results["lr"]
         losses = self.results["loss"]
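
Note: the series above converges on an optional-import pattern for matplotlib: the module-level import runs only for static type checkers, while plotting code checks availability and imports lazily at call time. Below is a minimal standalone sketch of that pattern, not part of the patches themselves. The helper `importlib.util.find_spec` stands in for Lightning's `_RequirementAvailable`, and the function name `plot_curve` is hypothetical.

import importlib.util
from typing import TYPE_CHECKING, Optional, Sequence

# Stand-in for pytorch_lightning.utilities.imports._RequirementAvailable (assumption).
_MATPLOTLIB_AVAILABLE: bool = importlib.util.find_spec("matplotlib") is not None

# Imported only for type checkers, so annotations can reference plt.Figure
# without requiring matplotlib at runtime.
if _MATPLOTLIB_AVAILABLE and TYPE_CHECKING:
    import matplotlib.pyplot as plt


def plot_curve(lrs: Sequence[float], losses: Sequence[float]) -> Optional["plt.Figure"]:
    """Plot a learning-rate/loss curve, failing clearly when matplotlib is missing."""
    if not _MATPLOTLIB_AVAILABLE:
        raise RuntimeError(
            "To use `plot_curve`, you must have Matplotlib installed."
            " Install it by running `pip install -U matplotlib`."
        )
    import matplotlib.pyplot as plt  # runtime import, deferred until actually needed

    fig, ax = plt.subplots()
    ax.plot(lrs, losses)
    ax.set_xscale("log")
    ax.set_xlabel("learning rate")
    ax.set_ylabel("loss")
    return fig

Keeping the runtime import inside the function mirrors patch 13: type annotations stay usable under `TYPE_CHECKING`, and environments without matplotlib only pay the cost (and see the error) when plotting is actually requested.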