diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index d652c7fbb35d7..3decad4c79353 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -18,6 +18,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support to upgrade all checkpoints in a folder using the `pl.utilities.upgrade_checkpoint` script ([#15333](https://github.com/Lightning-AI/lightning/pull/15333)) +- Add an axes argument `ax` to the `.lr_find().plot()` to enable writing to a user-defined axes in a matplotlib figure ([#15652](https://github.com/Lightning-AI/lightning/pull/15652)) + + - Added a check to validate that wrapped FSDP models are used while initializing optimizers ([#15301](https://github.com/Lightning-AI/lightning/pull/15301)) diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py index 63d7c09abb26e..29a5d47776a9e 100644 --- a/src/pytorch_lightning/tuner/lr_finder.py +++ b/src/pytorch_lightning/tuner/lr_finder.py @@ -38,9 +38,9 @@ from tqdm import tqdm _MATPLOTLIB_AVAILABLE = RequirementCache("matplotlib") -if _MATPLOTLIB_AVAILABLE and TYPE_CHECKING: +if TYPE_CHECKING and _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt - + from matplotlib.axes import Axes log = logging.getLogger(__name__) @@ -130,12 +130,14 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None: trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)] _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs) - def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figure"]: + def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] = None) -> Optional["plt.Figure"]: """Plot results from lr_find run Args: suggest: if True, will mark suggested lr to use with a red point show: if True, will show figure + + ax: Axes object to which the plot is to be drawn. If not provided, a new figure is created. """ if not _MATPLOTLIB_AVAILABLE: raise MisconfigurationException( @@ -147,7 +149,10 @@ def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figur lrs = self.results["lr"] losses = self.results["loss"] - fig, ax = plt.subplots() + if ax is None: + fig, ax = plt.subplots() + else: + fig = ax.figure # Plot loss as a function of the learning rate ax.plot(lrs, losses)