Skip to content

Commit bf1394a

Browse files
authored
improve early stopping verbose logging (#6811)
1 parent 393b252 commit bf1394a

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
145145
- Added warning when missing `Callback` and using `resume_from_checkpoint` ([#7254](https://github.com/PyTorchLightning/pytorch-lightning/pull/7254))
146146

147147

148+
- Improved verbose logging for `EarlyStopping` callback ([#6811](https://github.com/PyTorchLightning/pytorch-lightning/pull/6811))
149+
150+
148151
### Changed
149152

150153

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import numpy as np
2525
import torch
2626

27+
import pytorch_lightning as pl
2728
from pytorch_lightning.callbacks.base import Callback
2829
from pytorch_lightning.utilities import rank_zero_warn
2930
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -196,8 +197,8 @@ def _run_early_stopping_check(self, trainer) -> None:
196197
trainer.should_stop = trainer.should_stop or should_stop
197198
if should_stop:
198199
self.stopped_epoch = trainer.current_epoch
199-
if reason:
200-
log.info(f"[{trainer.global_rank}] {reason}")
200+
if reason and self.verbose:
201+
self._log_info(trainer, reason)
201202

202203
def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
203204
should_stop = False
@@ -224,15 +225,34 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
224225
)
225226
elif self.monitor_op(current - self.min_delta, self.best_score):
226227
should_stop = False
228+
reason = self._improvement_message(current)
227229
self.best_score = current
228230
self.wait_count = 0
229231
else:
230232
self.wait_count += 1
231233
if self.wait_count >= self.patience:
232234
should_stop = True
233235
reason = (
234-
f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} epochs."
236+
f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
235237
f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
236238
)
237239

238240
return should_stop, reason
241+
242+
def _improvement_message(self, current: torch.Tensor) -> str:
    """Format a log message that informs the user about an improvement in the monitored score.

    Args:
        current: the newest value of the monitored metric. It is rendered with
            the ``.3f`` format spec, so it is expected to be a scalar value.

    Returns:
        The message text. When the previous best score is finite, the message
        also reports the absolute size of the improvement and the absolute
        value of ``min_delta``; otherwise only the new best score is reported.
    """
    if torch.isfinite(self.best_score):
        # A finite previous best exists, so the difference to it can be shown.
        msg = (
            f"Metric {self.monitor} improved by {abs(self.best_score - current):.3f} >="
            f" min_delta = {abs(self.min_delta)}. New best score: {current:.3f}"
        )
    else:
        # Previous best is not finite (inf/nan): a difference would print as
        # inf/nan, so only the new score is mentioned.
        msg = f"Metric {self.monitor} improved. New best score: {current:.3f}"
    return msg
252+
253+
@staticmethod
def _log_info(trainer: Optional["pl.Trainer"], message: str) -> None:
    """Emit ``message`` at INFO level, prefixed with the process rank when relevant.

    Args:
        trainer: the trainer whose ``world_size`` and ``global_rank`` are
            consulted; may be ``None``, in which case no prefix is added.
        message: the text to log.
    """
    if trainer is not None and trainer.world_size > 1:
        # More than one process: tag the line so the emitting rank is visible.
        log.info(f"[rank: {trainer.global_rank}] {message}")
    else:
        log.info(message)

0 commit comments

Comments
 (0)