Skip to content

Commit aab8f48

Browse files
bipinKrishnanawaelchlicarmocca
authored
Add axes argument to lr finder plot (#15652)
Co-authored-by: awaelchli <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent cdb7006 commit aab8f48

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

src/pytorch_lightning/CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1818
- Added support to upgrade all checkpoints in a folder using the `pl.utilities.upgrade_checkpoint` script ([#15333](https://github.com/Lightning-AI/lightning/pull/15333))
1919

2020

21+
- Add an axes argument `ax` to the `.lr_find().plot()` to enable writing to a user-defined axes in a matplotlib figure ([#15652](https://github.com/Lightning-AI/lightning/pull/15652))
22+
23+
2124
- Added a check to validate that wrapped FSDP models are used while initializing optimizers ([#15301](https://github.com/Lightning-AI/lightning/pull/15301))
2225

2326

src/pytorch_lightning/tuner/lr_finder.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@
3838
from tqdm import tqdm
3939

4040
_MATPLOTLIB_AVAILABLE = RequirementCache("matplotlib")
41-
if _MATPLOTLIB_AVAILABLE and TYPE_CHECKING:
41+
if TYPE_CHECKING and _MATPLOTLIB_AVAILABLE:
4242
import matplotlib.pyplot as plt
43-
43+
from matplotlib.axes import Axes
4444
log = logging.getLogger(__name__)
4545

4646

@@ -130,12 +130,14 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None:
130130
trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
131131
_set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)
132132

133-
def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figure"]:
133+
def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] = None) -> Optional["plt.Figure"]:
134134
"""Plot results from lr_find run
135135
Args:
136136
suggest: if True, will mark suggested lr to use with a red point
137137
138138
show: if True, will show figure
139+
140+
ax: Axes object to which the plot is to be drawn. If not provided, a new figure is created.
139141
"""
140142
if not _MATPLOTLIB_AVAILABLE:
141143
raise MisconfigurationException(
@@ -147,7 +149,10 @@ def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figur
147149
lrs = self.results["lr"]
148150
losses = self.results["loss"]
149151

150-
fig, ax = plt.subplots()
152+
if ax is None:
153+
fig, ax = plt.subplots()
154+
else:
155+
fig = ax.figure
151156

152157
# Plot loss as a function of the learning rate
153158
ax.plot(lrs, losses)

0 commit comments

Comments
 (0)