Skip to content

Commit b046bd0

Browse files
Add on_exception callback hook (#9183)
1 parent ff7305f commit b046bd0

File tree

9 files changed

+37
-8
lines changed

9 files changed

+37
-8
lines changed

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
9797
- Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))
9898

9999

100-
- Add support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084))
100+
- Added support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084))
101101

102102

103+
- Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183))
104+
103105
### Changed
104106

105107
- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
@@ -167,7 +169,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
167169

168170
- Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` ([#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098))
169171

170-
-
171172

172173
- Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162))
173174

pytorch_lightning/callbacks/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,10 @@ def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningM
267267
"""Called when the training is interrupted by ``KeyboardInterrupt``."""
268268
pass
269269

270+
def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
271+
"""Called when any trainer execution is interrupted by an exception."""
272+
pass
273+
270274
def on_save_checkpoint(
271275
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
272276
) -> dict:

pytorch_lightning/callbacks/lambda_function.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(
7575
on_test_start: Optional[Callable] = None,
7676
on_test_end: Optional[Callable] = None,
7777
on_keyboard_interrupt: Optional[Callable] = None,
78+
on_exception: Optional[Callable] = None,
7879
on_save_checkpoint: Optional[Callable] = None,
7980
on_load_checkpoint: Optional[Callable] = None,
8081
on_before_backward: Optional[Callable] = None,

pytorch_lightning/trainer/callback_hook.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,11 @@ def on_keyboard_interrupt(self):
236236
for callback in self.callbacks:
237237
callback.on_keyboard_interrupt(self, self.lightning_module)
238238

239+
def on_exception(self, exception: BaseException) -> None:
240+
"""Called when any trainer execution is interrupted by an exception."""
241+
for callback in self.callbacks:
242+
callback.on_exception(self, self.lightning_module, exception)
243+
239244
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
240245
"""Called when saving a model checkpoint."""
241246
callback_states = {}

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
# limitations under the License.
1414
import pytorch_lightning as pl
1515
from pytorch_lightning.trainer.states import TrainerFn
16-
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
1716
from pytorch_lightning.utilities.exceptions import MisconfigurationException
1817
from pytorch_lightning.utilities.model_helpers import is_overridden
1918
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
19+
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn
2020

2121

2222
class ConfigValidator:

pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class FxValidator:
6262
on_predict_batch_start=None,
6363
on_predict_batch_end=None,
6464
on_keyboard_interrupt=None,
65+
on_exception=None,
6566
on_save_checkpoint=None,
6667
on_load_checkpoint=None,
6768
setup=None,

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,20 +505,22 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs:
505505
"""
506506
try:
507507
return trainer_fn(*args, **kwargs)
508-
except KeyboardInterrupt:
508+
except KeyboardInterrupt as exception:
509509
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
510510
# user could press Ctrl+C many times... only shut down once
511511
if not self.interrupted:
512512
self.state.status = TrainerStatus.INTERRUPTED
513513
self.on_keyboard_interrupt()
514-
except BaseException:
514+
self.on_exception(exception)
515+
except BaseException as exception:
515516
self.state.status = TrainerStatus.INTERRUPTED
516517
if distributed_available() and self.world_size > 1:
517518
# try syncing remaining processes, kill otherwise
518519
self.training_type_plugin.reconciliate_processes(traceback.format_exc())
519520
self._on_exception()
520521
# reset bookkeeping
521522
self.state.stage = None
523+
self.on_exception(exception)
522524
raise
523525

524526
def fit(

tests/trainer/logging_/test_logger_connector.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def test_fx_validator(tmpdir):
4949
"on_init_end",
5050
"on_init_start",
5151
"on_keyboard_interrupt",
52+
"on_exception",
5253
"on_load_checkpoint",
5354
"on_pretrain_routine_end",
5455
"on_pretrain_routine_start",
@@ -91,6 +92,7 @@ def test_fx_validator(tmpdir):
9192
"on_init_end",
9293
"on_init_start",
9394
"on_keyboard_interrupt",
95+
"on_exception",
9496
"on_load_checkpoint",
9597
"on_pretrain_routine_end",
9698
"on_pretrain_routine_start",

tests/trainer/test_trainer.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -834,10 +834,10 @@ def on_after_backward(self):
834834
assert not torch.isfinite(params).all()
835835

836836

837-
def test_trainer_interrupted_flag(tmpdir):
838-
"""Test the flag denoting that a user interrupted training."""
837+
def test_on_exception_hook(tmpdir):
838+
"""Test the on_exception callback hook and the trainer interrupted flag."""
839839

840-
model = EvalModelTemplate()
840+
model = BoringModel()
841841

842842
class InterruptCallback(Callback):
843843
def __init__(self):
@@ -846,11 +846,18 @@ def __init__(self):
846846
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
847847
raise KeyboardInterrupt
848848

849+
def on_test_start(self, trainer, pl_module):
850+
raise MisconfigurationException
851+
849852
class HandleInterruptCallback(Callback):
850853
def __init__(self):
851854
super().__init__()
855+
self.exception = None
852856
self.exc_info = None
853857

858+
def on_exception(self, trainer, pl_module, exception):
859+
self.exception = exception
860+
854861
def on_keyboard_interrupt(self, trainer, pl_module):
855862
self.exc_info = sys.exc_info()
856863

@@ -867,10 +874,16 @@ def on_keyboard_interrupt(self, trainer, pl_module):
867874
default_root_dir=tmpdir,
868875
)
869876
assert not trainer.interrupted
877+
assert handle_interrupt_callback.exception is None
870878
assert handle_interrupt_callback.exc_info is None
871879
trainer.fit(model)
872880
assert trainer.interrupted
881+
assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt)
873882
assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt)
883+
with pytest.raises(MisconfigurationException):
884+
trainer.test(model)
885+
assert trainer.interrupted
886+
assert isinstance(handle_interrupt_callback.exception, MisconfigurationException)
874887

875888

876889
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)