-
Notifications
You must be signed in to change notification settings - Fork 3.5k
[Feat] Add graceful detection of signal to exit + SignalConnector and merge SlurmConnector. #9566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 10 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
b637fc1
update
tchaton 47cc2ec
update
tchaton 83da52f
cleanup
tchaton e0f5004
address comments
tchaton 6f26558
update
tchaton bea0adc
wip
tchaton d685bc3
update
tchaton dd7002f
update
tchaton 82d3192
update
tchaton dc62236
update
tchaton 5f677d4
update
tchaton 3e522a3
update
tchaton 11b2b93
update
tchaton a8a7ac5
resolve mypy
tchaton 1af0fdd
Update CHANGELOG.md
tchaton 5034bd9
update
tchaton 9130d92
update
tchaton 9fc7d0a
resolve mypy
tchaton c100d1e
update
tchaton 7216a39
update
tchaton File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
121 changes: 121 additions & 0 deletions
121
pytorch_lightning/trainer/connectors/signal_connector.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import logging | ||
tchaton marked this conversation as resolved.
Show resolved
Hide resolved
|
||
import os | ||
import signal | ||
from signal import Signals | ||
from subprocess import call | ||
from typing import Callable, List, Optional, Union | ||
|
||
import _signal | ||
|
||
from pytorch_lightning.utilities.imports import _fault_tolerant_training | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class HandlersCompose:
    """Aggregate several signal handlers behind one callable.

    ``signal.signal`` accepts a single handler; this wrapper fans a received
    signal out to every handler in the list, in registration order.
    """

    def __init__(self, signal_handlers: Union[List[Callable], Callable]):
        # Normalize a bare callable to a one-element list so both forms
        # are handled uniformly at dispatch time.
        self.signal_handlers = signal_handlers if isinstance(signal_handlers, list) else [signal_handlers]

    def __call__(self, signum, frame):
        """Invoke each registered handler with the received signal and frame."""
        for handler in self.signal_handlers:
            handler(signum, frame)
class SignalConnector:
    """Registers process-signal handlers for the Trainer.

    Composes the fault-tolerant SIGUSR1 handler, the SLURM auto-requeue
    SIGUSR1 handler and (under SLURM only) a SIGTERM bypass handler.
    User-supplied ``sigusr1_handler`` / ``sigterm_handler`` callables take
    precedence over the composed defaults.
    """

    def __init__(self, trainer, sigusr1_handler: Optional[Callable] = None, sigterm_handler: Optional[Callable] = None):
        self.trainer = trainer
        # Flag flipped by the fault-tolerant handler; training loops poll it
        # to finish the current step and exit cleanly.
        self.trainer._terminate_gracefully = False
        self._sigusr1_handler = sigusr1_handler
        self._sigterm_handler = sigterm_handler

    @property
    def sigusr1_handler(self) -> Optional[Callable]:
        """Optional user-provided override for the SIGUSR1 handler."""
        return self._sigusr1_handler

    @sigusr1_handler.setter
    def sigusr1_handler(self, sigusr1_handler: Callable) -> None:
        self._sigusr1_handler = sigusr1_handler

    @property
    def sigterm_handler(self) -> Optional[Callable]:
        """Optional user-provided override for the SIGTERM handler."""
        return self._sigterm_handler

    @sigterm_handler.setter
    def sigterm_handler(self, sigterm_handler: Callable) -> None:
        self._sigterm_handler = sigterm_handler

    def register_signal_handlers(self) -> None:
        """Install SIGUSR1/SIGTERM handlers unless a handler is already set."""
        sigusr1_handlers: List[Callable] = []
        sigterm_handlers: List[Callable] = []

        if _fault_tolerant_training():
            sigusr1_handlers.append(self.fault_tolerant_sigusr1_handler_fn)

        if self._is_on_slurm():
            log.info("Set SLURM handle signals.")
            sigusr1_handlers.append(self.slurm_sigusr1_handler_fn)
            # Only bypass SIGTERM while running under SLURM auto-requeue.
            # This append was previously unconditional (accidentally
            # un-indented out of this branch), which swallowed SIGTERM for
            # every run — not just SLURM jobs.
            sigterm_handlers.append(self.sigterm_handler_fn)

        # A user-provided handler takes precedence over the composed defaults.
        # Skip registration entirely when there is nothing to register, so the
        # OS default disposition is preserved instead of a no-op handler.
        sigusr1 = self.sigusr1_handler or sigusr1_handlers
        if sigusr1 and not self._has_already_handler(signal.SIGUSR1):
            signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1))

        sigterm = self.sigterm_handler or sigterm_handlers
        if sigterm and not self._has_already_handler(signal.SIGTERM):
            signal.signal(signal.SIGTERM, HandlersCompose(sigterm))

    def _is_on_slurm(self) -> bool:
        """Return ``True`` when running inside a non-interactive SLURM job.

        An interactive ``srun``/login shell reports the job name "bash";
        only real batch jobs count as SLURM-managed.
        """
        return os.environ.get("SLURM_JOB_NAME", "bash") != "bash"

    def slurm_sigusr1_handler_fn(self, signum, frame) -> None:  # pragma: no-cover
        """On SLURM's SIGUSR1: save an HPC checkpoint and requeue the job."""
        if self.trainer.is_global_zero:
            # save weights
            log.info("handling SIGUSR1")
            self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger)

            # find job id
            job_id = os.environ["SLURM_JOB_ID"]
            cmd = ["scontrol", "requeue", job_id]

            # requeue job
            log.info(f"requeing job {job_id}...")
            try:
                result = call(cmd)
            except FileNotFoundError:
                # This can occur if a subprocess call to `scontrol` is run outside a shell context.
                # Re-attempt call (now with shell context). If any error is raised, propagate to user.
                # When running a shell command, it should be passed as a single string.
                result = call(" ".join(cmd), shell=True)

            # print result text
            if result == 0:
                log.info(f"requeued exp {job_id}")
            else:
                log.warning("requeue failed...")

            # close experiment to avoid issues
            # NOTE(review): assumes ``trainer.logger`` is not None here — confirm
            # against the callers before relying on it.
            self.trainer.logger.close()

    def fault_tolerant_sigusr1_handler_fn(self, signum, frame) -> None:  # pragma: no-cover
        """Ask the Trainer to terminate gracefully at the next opportunity."""
        self.trainer._terminate_gracefully = True

    def sigterm_handler_fn(self, signum, frame) -> None:  # pragma: no-cover
        """Swallow SIGTERM so SLURM's pre-requeue signal does not kill the job."""
        log.info("bypassing sigterm")

    def _has_already_handler(self, signum: Signals) -> bool:
        """Return ``True`` when a Python-level handler is installed for ``signum``.

        ``signal.getsignal`` returns ``SIG_DFL``/``SIG_IGN``/``None`` for
        non-callable dispositions; only a user-installed callable counts.
        (Uses the public ``signal`` API instead of the private ``_signal``
        C module, and ``callable()`` instead of ``isinstance(..., Callable)``.)
        """
        return callable(signal.getsignal(signum))
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os | ||
import signal | ||
from time import sleep | ||
from unittest import mock | ||
|
||
import pytest | ||
|
||
from pytorch_lightning import Trainer | ||
from tests.helpers import BoringModel | ||
from tests.helpers.runif import RunIf | ||
|
||
|
||
@pytest.mark.parametrize("register_handler", [False, True])
@pytest.mark.parametrize("terminate_gracefully", [False, True])
@RunIf(min_torch="1.7.0", skip_windows=True)
def test_fault_tolerant_sig_handler(register_handler, terminate_gracefully, tmpdir):
    """SIGUSR1 flips ``trainer._terminate_gracefully`` only when fault-tolerant
    training is enabled and no user handler is already registered."""
    # hack: restore the default disposition so earlier tests cannot leak a handler
    signal.signal(signal.SIGUSR1, 0)

    if register_handler:
        # A pre-existing handler must win over the connector's own handlers.
        signal.signal(signal.SIGUSR1, lambda *_: None)

    env_patch = {"PL_FAULT_TOLERANT_TRAINING": str(int(terminate_gracefully))}
    with mock.patch.dict(os.environ, env_patch):

        class TestModel(BoringModel):
            def training_step(self, batch, batch_idx):
                if terminate_gracefully or register_handler:
                    # Deliver SIGUSR1 to ourselves mid-step and give the
                    # handler a moment to run before continuing.
                    os.kill(os.getpid(), signal.SIGUSR1)
                    sleep(0.1)
                return super().training_step(batch, batch_idx)

        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0)
        trainer.fit(TestModel())
        expected = terminate_gracefully and not register_handler
        assert trainer._terminate_gracefully == expected
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.