Skip to content

Commit 5e6db79

Browse files
carmoccalexierule
authored andcommitted
Squeeze the early stopping monitor (#10461)
1 parent 391e0d6 commit 5e6db79

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3030
- Fixed an issue that prevented the Trainer to shutdown workers when execution is interrupted due to failure([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))
3131

3232

33-
-
33+
- Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))
3434

3535

3636
-

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
202202
): # short circuit if metric not present
203203
return
204204

205-
current = logs.get(self.monitor)
205+
current = logs[self.monitor].squeeze()
206206
should_stop, reason = self._evaluate_stopping_criteria(current)
207207

208208
# stop every ddp process if any world process decides to stop

tests/callbacks/test_early_stopping.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,16 @@ def validation_step(self, batch, batch_idx):
469469
assert trainer.global_step == len(side_effect) * int(trainer.limit_train_batches * trainer.val_check_interval)
470470
else:
471471
assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch - 1
472+
473+
474+
def test_early_stopping_squeezes():
475+
early_stopping = EarlyStopping(monitor="foo")
476+
trainer = Trainer()
477+
trainer.callback_metrics["foo"] = torch.tensor([[[0]]])
478+
479+
with mock.patch(
480+
"pytorch_lightning.callbacks.EarlyStopping._evaluate_stopping_criteria", return_value=(False, "")
481+
) as es_mock:
482+
early_stopping._run_early_stopping_check(trainer)
483+
484+
es_mock.assert_called_once_with(torch.tensor(0))

0 commit comments

Comments
 (0)