Commit c7451b3

[Feat] Add graceful detection of signal to exit + SignalConnector and merge SlurmConnector. (#9566)

Authored by tchaton and Sean Naren
Co-authored-by: Sean Naren <[email protected]>
1 parent 856ed10 commit c7451b3
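
The mechanism this commit introduces is deliberately small: the OS signal handler only flips a flag on the Trainer (`trainer._terminate_gracefully`), and the training loop can later check that flag and stop at a safe point. Below is a minimal standalone sketch of the same pattern; it is not code from this commit, and the `GracefulRunner` class and `run_training` loop are illustrative names only.

import os
import signal
import time


class GracefulRunner:
    """Minimal stand-in for the Trainer/SignalConnector pair: the handler only sets a flag."""

    def __init__(self) -> None:
        self.terminate_gracefully = False
        # SIGUSR1 mirrors the commit's choice; it is not available on Windows.
        signal.signal(signal.SIGUSR1, self._handler)

    def _handler(self, signum, frame) -> None:
        # Do nothing risky inside the handler; just record that a stop was requested.
        self.terminate_gracefully = True

    def run_training(self, max_steps: int = 100) -> None:
        for step in range(max_steps):
            time.sleep(0.01)  # stand-in for one training step
            if self.terminate_gracefully:
                print(f"stopping gracefully after step {step}")
                break


if __name__ == "__main__":
    runner = GracefulRunner()
    # Simulate an external scheduler (e.g. SLURM) signalling this process.
    os.kill(os.getpid(), signal.SIGUSR1)
    runner.run_training()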

File tree

7 files changed: +172 −66 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
 * Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401))
 * Added support for restarting an optimizer loop (multiple optimizers) ([#9537](https://github.com/PyTorchLightning/pytorch-lightning/pull/9537))
+* Added mechanism to detect a signal has been sent so the Trainer can gracefully exit ([#9566](https://github.com/PyTorchLightning/pytorch-lightning/pull/9566))


 - Checkpoint saving & loading extensibility:

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@ module = [
     "pytorch_lightning.loops.evaluation_loop",
     "pytorch_lightning.trainer.connectors.checkpoint_connector",
     "pytorch_lightning.trainer.connectors.logger_connector.*",
+    "pytorch_lightning.trainer.connectors.signal_connector",
     "pytorch_lightning.trainer.progress",
     "pytorch_lightning.tuner.auto_gpu_select",
     "pytorch_lightning.utilities.apply_func",

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 3 additions & 2 deletions
@@ -274,14 +274,15 @@ def restore_lr_schedulers(self) -> None:
     # PRIVATE OPS
     # ----------------------------------

-    def hpc_save(self, folderpath: str, logger: LightningLoggerBase) -> str:
+    def hpc_save(self, folderpath: str, logger: Optional[LightningLoggerBase]) -> str:
         # make sure the checkpoint folder exists
         folderpath = str(folderpath)  # because the tests pass a path object
         fs = get_filesystem(folderpath)
         fs.makedirs(folderpath, exist_ok=True)

         # save logger to make sure we get all the metrics
-        logger.save()
+        if logger:
+            logger.finalize("finished")

         max_suffix = self.max_ckpt_version_in_folder(folderpath)
         ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
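
The `Optional` signature matters because the new `slurm_sigusr1_handler_fn` (in the file below) passes `self.trainer.logger`, which is `None` when no logger is configured; the old unconditional `logger.save()` would then raise `AttributeError`. A tiny hedged sketch of the same guard, using a stand-in `DummyLogger` rather than a real `LightningLoggerBase`:

from typing import Optional


class DummyLogger:
    """Stand-in for a LightningLoggerBase, used only to exercise the guard."""

    def finalize(self, status: str) -> None:
        print(f"logger finalized with status={status!r}")


def close_logger(logger: Optional[DummyLogger]) -> None:
    # Mirrors the guard added to hpc_save: a missing logger is simply skipped.
    if logger:
        logger.finalize("finished")


close_logger(DummyLogger())  # prints the finalize message
close_logger(None)           # no-op instead of raising AttributeError
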
pytorch_lightning/trainer/connectors/signal_connector.py (new file)

Lines changed: 110 additions & 0 deletions

import logging
import os
import signal
import sys
from signal import Signals
from subprocess import call
from types import FrameType, FunctionType
from typing import Callable, List, Union

import pytorch_lightning as pl
from pytorch_lightning.utilities.imports import _fault_tolerant_training

log = logging.getLogger(__name__)


class HandlersCompose:
    def __init__(self, signal_handlers: Union[List[Callable], Callable]):
        if not isinstance(signal_handlers, list):
            signal_handlers = [signal_handlers]
        self.signal_handlers = signal_handlers

    def __call__(self, signum: Signals, frame: FrameType) -> None:
        for signal_handler in self.signal_handlers:
            signal_handler(signum, frame)


class SignalConnector:
    def __init__(self, trainer: "pl.Trainer"):
        self.trainer = trainer
        self.trainer._terminate_gracefully = False

    def register_signal_handlers(self) -> None:
        sigusr1_handlers: List[Callable] = []
        sigterm_handlers: List[Callable] = []

        if _fault_tolerant_training():
            sigusr1_handlers.append(self.fault_tolerant_sigusr1_handler_fn)

        if self._is_on_slurm():
            log.info("Set SLURM handle signals.")
            sigusr1_handlers.append(self.slurm_sigusr1_handler_fn)

        sigterm_handlers.append(self.sigterm_handler_fn)

        # signal.SIGUSR1 doesn't seem available on windows
        if not self._is_on_windows():
            if not self._has_already_handler(signal.SIGUSR1):
                signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers))

            if not self._has_already_handler(signal.SIGTERM):
                signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))

    def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
        if self.trainer.is_global_zero:
            # save weights
            log.info("handling SIGUSR1")
            self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger)

            # find job id
            job_id = os.environ["SLURM_JOB_ID"]
            cmd = ["scontrol", "requeue", job_id]

            # requeue job
            log.info(f"requeing job {job_id}...")
            try:
                result = call(cmd)
            except FileNotFoundError:
                # This can occur if a subprocess call to `scontrol` is run outside a shell context
                # Re-attempt call (now with shell context). If any error is raised, propagate to user.
                # When running a shell command, it should be passed as a single string.
                joint_cmd = [str(x) for x in cmd]
                result = call(" ".join(joint_cmd), shell=True)

            # print result text
            if result == 0:
                log.info(f"requeued exp {job_id}")
            else:
                log.warning("requeue failed...")

            # close experiment to avoid issues
            if self.trainer.logger:
                self.trainer.logger.finalize("finished")

    def fault_tolerant_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
        self.trainer._terminate_gracefully = True

    def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None:
        log.info("bypassing sigterm")

    def _is_on_slurm(self) -> bool:
        # see if we're using slurm (not interactive)
        on_slurm = False
        try:
            job_name = os.environ["SLURM_JOB_NAME"]
            if job_name != "bash":
                on_slurm = True
        # todo: specify the possible exception
        except Exception:
            pass

        return on_slurm

    def _is_on_windows(self) -> bool:
        return sys.platform == "win32"

    def _has_already_handler(self, signum: Signals) -> bool:
        try:
            return isinstance(signal.getsignal(signum), FunctionType)
        except AttributeError:
            return False
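
A short usage note on `HandlersCompose`: it exists so that several independent reactions (here, the fault-tolerant flag and the SLURM checkpoint-and-requeue logic) can share the single handler slot the OS gives each signal. A hedged example, assuming a PyTorch Lightning install that includes this commit; the two handler functions are illustrative only:

import os
import signal

from pytorch_lightning.trainer.connectors.signal_connector import HandlersCompose


def set_flag(signum, frame):
    print("handler 1: marking the run for graceful shutdown")


def checkpoint_and_requeue(signum, frame):
    print("handler 2: would checkpoint and requeue the job here")


# Both reactions run from a single SIGUSR1 registration (Unix only).
signal.signal(signal.SIGUSR1, HandlersCompose([set_flag, checkpoint_and_requeue]))
os.kill(os.getpid(), signal.SIGUSR1)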

pytorch_lightning/trainer/connectors/slurm_connector.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 4 deletions
@@ -52,7 +52,7 @@
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
 from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
-from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
+from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
 from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.deprecated_api import DeprecatedTrainerAttributes

@@ -383,7 +383,7 @@ def __init__(
         self.debugging_connector = DebuggingConnector(self)
         self.training_tricks_connector = TrainingTricksConnector(self)
         self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
-        self.slurm_connector = SLURMConnector(self)
+        self.signal_connector = SignalConnector(self)
         self.tuner = Tuner(self)

         # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1).

@@ -1104,8 +1104,8 @@ def _pre_training_routine(self):
         # wait for all to join if on distributed
         self.accelerator.barrier("setup_training")

-        # register auto-resubmit when on SLURM
-        self.slurm_connector.register_slurm_signal_handlers()
+        # register signals
+        self.signal_connector.register_signal_handlers()

         self.checkpoint_connector.resume_end()
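
With this wiring, users do not register anything themselves: constructing a Trainer builds the SignalConnector, and `fit()` registers the handlers during `_pre_training_routine`. A minimal hedged usage sketch (`BoringModel` is the repository test helper imported in the test below; any LightningModule works):

from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # repository test helper; any LightningModule works

trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=0)
trainer.fit(BoringModel())  # SignalConnector.register_signal_handlers() runs before the first batch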

New test file for the signal handlers

Lines changed: 53 additions & 0 deletions

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import signal
from time import sleep
from unittest import mock

import pytest

from pytorch_lightning import Trainer
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


@pytest.mark.parametrize("register_handler", [False, True])
@pytest.mark.parametrize("terminate_gracefully", [False, True])
@RunIf(min_torch="1.7.0", skip_windows=True)
def test_fault_tolerant_sig_handler(register_handler, terminate_gracefully, tmpdir):

    # hack to reset the signal
    signal.signal(signal.SIGUSR1, 0)

    if register_handler:

        def handler(*_):
            pass

        signal.signal(signal.SIGUSR1, handler)

    with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(terminate_gracefully))}):

        class TestModel(BoringModel):
            def training_step(self, batch, batch_idx):
                if terminate_gracefully or register_handler:
                    os.kill(os.getpid(), signal.SIGUSR1)
                    sleep(0.1)
                return super().training_step(batch, batch_idx)

        model = TestModel()
        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0)
        trainer.fit(model)
        assert trainer._terminate_gracefully == (False if register_handler else terminate_gracefully)
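
Beyond this test, nothing in the commit consumes `trainer._terminate_gracefully` yet; downstream code could poll it, for example from a callback. A purely illustrative sketch, not part of this commit:

from pytorch_lightning import Callback, Trainer


class GracefulExitCallback(Callback):
    """Illustrative only: request a clean stop once the signal handler has flipped the flag."""

    def on_train_batch_end(self, trainer: Trainer, *args, **kwargs) -> None:
        if getattr(trainer, "_terminate_gracefully", False):
            trainer.should_stop = True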
