From b637fc110b33f5a9531e6999e5d59346e7bd3878 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 16 Sep 2021 15:20:19 +0100 Subject: [PATCH 01/19] update --- .../connectors/fault_tolerant_connector.py | 23 ++++++++++ pytorch_lightning/trainer/trainer.py | 5 ++ .../connectors/fault_tolerant_pid_killer.py | 17 +++++++ .../test_fault_tolerant_connector.py | 46 +++++++++++++++++++ 4 files changed, 91 insertions(+) create mode 100644 pytorch_lightning/trainer/connectors/fault_tolerant_connector.py create mode 100644 tests/trainer/connectors/fault_tolerant_pid_killer.py create mode 100644 tests/trainer/connectors/test_fault_tolerant_connector.py diff --git a/pytorch_lightning/trainer/connectors/fault_tolerant_connector.py b/pytorch_lightning/trainer/connectors/fault_tolerant_connector.py new file mode 100644 index 0000000000000..081a7404c5d5c --- /dev/null +++ b/pytorch_lightning/trainer/connectors/fault_tolerant_connector.py @@ -0,0 +1,23 @@ +import logging +import signal + +from pytorch_lightning.utilities.imports import _fault_tolerant_training + +log = logging.getLogger(__name__) + + +class FaultTolerantConnector: + def __init__(self, trainer): + self.trainer = trainer + self.trainer._should_gracefully_terminate = False + + def register_fault_tolerant_signal_handlers(self): + if _fault_tolerant_training(): + signal.signal(signal.SIGUSR1, self.sig_handler) + signal.signal(signal.SIGTERM, self.term_handler) + + def sig_handler(self, signum, frame): # pragma: no-cover + self.trainer._should_gracefully_terminate = True + + def term_handler(self, signum, frame): # pragma: no-cover + log.info("bypassing sigterm") diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d705925aa8dfe..afa934a68f7fe 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -49,6 +49,7 @@ from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars +from pytorch_lightning.trainer.connectors.fault_tolerant_connector import FaultTolerantConnector from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector @@ -384,6 +385,7 @@ def __init__( self.training_tricks_connector = TrainingTricksConnector(self) self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint) self.slurm_connector = SLURMConnector(self) + self.fault_tolerant_connector = FaultTolerantConnector(self) self.tuner = Tuner(self) # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1). 
@@ -1099,6 +1101,9 @@ def _pre_training_routine(self): # register auto-resubmit when on SLURM self.slurm_connector.register_slurm_signal_handlers() + # used to register gracefully detection of signal to ensure clean exit mechanism + self.fault_tolerant_connector.register_fault_tolerant_signal_handlers() + self.checkpoint_connector.resume_end() # -------------------------- diff --git a/tests/trainer/connectors/fault_tolerant_pid_killer.py b/tests/trainer/connectors/fault_tolerant_pid_killer.py new file mode 100644 index 0000000000000..33c4692d7fa6c --- /dev/null +++ b/tests/trainer/connectors/fault_tolerant_pid_killer.py @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import signal + +os.kill(int(os.getenv("PID", None)), signal.SIGUSR1) diff --git a/tests/trainer/connectors/test_fault_tolerant_connector.py b/tests/trainer/connectors/test_fault_tolerant_connector.py new file mode 100644 index 0000000000000..84968944f3673 --- /dev/null +++ b/tests/trainer/connectors/test_fault_tolerant_connector.py @@ -0,0 +1,46 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import subprocess +import sys +from time import sleep +from unittest import mock + +import pytest + +from pytorch_lightning import Trainer +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +@pytest.mark.parametrize("should_gracefully_terminate", [False, True]) +@RunIf(min_torch="1.7.0", special=True) +def test_fault_tolerant_sig_handler(should_gracefully_terminate, tmpdir): + + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(should_gracefully_terminate))}): + + class TestModel(BoringModel): + def training_step(self, batch, batch_idx): + if should_gracefully_terminate and self.trainer.current_epoch == 1 and batch_idx == 1: + env_copy = os.environ.copy() + env_copy["PID"] = str(os.getpid()) + command = [sys.executable, os.path.join(os.path.dirname(__file__), "fault_tolerant_pid_killer.py")] + subprocess.Popen(command, env=env_copy) + sleep(0.1) + return super().training_step(batch, batch_idx) + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, limit_train_batches=2, limit_val_batches=2) + trainer.fit(model) + assert trainer._should_gracefully_terminate == should_gracefully_terminate From 47cc2eca88d70611c350354756f46c66b69a3940 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 16 Sep 2021 15:24:34 +0100 Subject: [PATCH 02/19] update --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 42a720e5c0d1a..1b1f6c483a5b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,6 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950)) * Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401)) * Added support for restarting an optimizer loop (multiple optimizers) ([#9537](https://github.com/PyTorchLightning/pytorch-lightning/pull/9537)) + * Added a mechanism to detect a signal as been sent so the Trainer can gracefully exit ([#9566](https://github.com/PyTorchLightning/pytorch-lightning/pull/9566)) - Checkpoint saving & loading extensibility: From 83da52f58e6d01a08c080d9370149d6b586b0367 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 16 Sep 2021 15:48:19 +0100 Subject: [PATCH 03/19] cleanup --- ...olerant_connector.py => signal_connector.py} | 8 +++++--- pytorch_lightning/trainer/trainer.py | 7 +++---- .../connectors/fault_tolerant_pid_killer.py | 17 ----------------- ...nt_connector.py => test_signal_connector.py} | 8 ++------ 4 files changed, 10 insertions(+), 30 deletions(-) rename pytorch_lightning/trainer/connectors/{fault_tolerant_connector.py => signal_connector.py} (65%) delete mode 100644 tests/trainer/connectors/fault_tolerant_pid_killer.py rename tests/trainer/connectors/{test_fault_tolerant_connector.py => test_signal_connector.py} (83%) diff --git a/pytorch_lightning/trainer/connectors/fault_tolerant_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py similarity index 65% rename from pytorch_lightning/trainer/connectors/fault_tolerant_connector.py rename to pytorch_lightning/trainer/connectors/signal_connector.py index 081a7404c5d5c..5418d3fc2d966 100644 --- a/pytorch_lightning/trainer/connectors/fault_tolerant_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -1,18 +1,20 @@ import logging import signal +from 
pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities.imports import _fault_tolerant_training log = logging.getLogger(__name__) -class FaultTolerantConnector: +class SignalConnector: def __init__(self, trainer): self.trainer = trainer self.trainer._should_gracefully_terminate = False - def register_fault_tolerant_signal_handlers(self): - if _fault_tolerant_training(): + def register_signal_handlers(self): + cluster_env = getattr(self.trainer.training_type_plugin, "cluster_environment", None) + if _fault_tolerant_training() and not isinstance(cluster_env, SLURMEnvironment): signal.signal(signal.SIGUSR1, self.sig_handler) signal.signal(signal.SIGTERM, self.term_handler) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index afa934a68f7fe..9de5fc1eca6a0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -49,10 +49,10 @@ from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars -from pytorch_lightning.trainer.connectors.fault_tolerant_connector import FaultTolerantConnector from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector +from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin @@ -385,7 +385,7 @@ def __init__( self.training_tricks_connector = TrainingTricksConnector(self) self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint) self.slurm_connector = SLURMConnector(self) - self.fault_tolerant_connector = FaultTolerantConnector(self) + self.signal_connector = SignalConnector(self) self.tuner = Tuner(self) # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1). @@ -1101,8 +1101,7 @@ def _pre_training_routine(self): # register auto-resubmit when on SLURM self.slurm_connector.register_slurm_signal_handlers() - # used to register gracefully detection of signal to ensure clean exit mechanism - self.fault_tolerant_connector.register_fault_tolerant_signal_handlers() + self.signal_connector.register_signal_handlers() self.checkpoint_connector.resume_end() diff --git a/tests/trainer/connectors/fault_tolerant_pid_killer.py b/tests/trainer/connectors/fault_tolerant_pid_killer.py deleted file mode 100644 index 33c4692d7fa6c..0000000000000 --- a/tests/trainer/connectors/fault_tolerant_pid_killer.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -import os -import signal - -os.kill(int(os.getenv("PID", None)), signal.SIGUSR1) diff --git a/tests/trainer/connectors/test_fault_tolerant_connector.py b/tests/trainer/connectors/test_signal_connector.py similarity index 83% rename from tests/trainer/connectors/test_fault_tolerant_connector.py rename to tests/trainer/connectors/test_signal_connector.py index 84968944f3673..9462881ebd497 100644 --- a/tests/trainer/connectors/test_fault_tolerant_connector.py +++ b/tests/trainer/connectors/test_signal_connector.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import subprocess -import sys +import signal from time import sleep from unittest import mock @@ -33,10 +32,7 @@ def test_fault_tolerant_sig_handler(should_gracefully_terminate, tmpdir): class TestModel(BoringModel): def training_step(self, batch, batch_idx): if should_gracefully_terminate and self.trainer.current_epoch == 1 and batch_idx == 1: - env_copy = os.environ.copy() - env_copy["PID"] = str(os.getpid()) - command = [sys.executable, os.path.join(os.path.dirname(__file__), "fault_tolerant_pid_killer.py")] - subprocess.Popen(command, env=env_copy) + os.kill(os.getpid(), signal.SIGUSR1) sleep(0.1) return super().training_step(batch, batch_idx) From e0f500463c692f37940daf224e833afcb62f239b Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 16 Sep 2021 15:54:11 +0100 Subject: [PATCH 04/19] address comments --- .../trainer/connectors/signal_connector.py | 59 +++++++++++++++++- .../trainer/connectors/slurm_connector.py | 60 ------------------- pytorch_lightning/trainer/trainer.py | 6 +- 3 files changed, 57 insertions(+), 68 deletions(-) delete mode 100644 pytorch_lightning/trainer/connectors/slurm_connector.py diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 5418d3fc2d966..07703f0680c7a 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -1,5 +1,7 @@ import logging +import os import signal +from subprocess import call from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -14,11 +16,62 @@ def __init__(self, trainer): def register_signal_handlers(self): cluster_env = getattr(self.trainer.training_type_plugin, "cluster_environment", None) - if _fault_tolerant_training() and not isinstance(cluster_env, SLURMEnvironment): - signal.signal(signal.SIGUSR1, self.sig_handler) + if isinstance(cluster_env, SLURMEnvironment): + self.register_slurm_signal_handlers() + elif _fault_tolerant_training(): + self.register_fault_tolerant_handlers() + + def register_fault_tolerant_handlers(self): + signal.signal(signal.SIGUSR1, self.sig_fault_tolerant_handler) + signal.signal(signal.SIGTERM, self.term_handler) + + def register_slurm_signal_handlers(self): + # see if we're using slurm (not interactive) + on_slurm = False + try: + job_name = os.environ["SLURM_JOB_NAME"] + if job_name != "bash": + on_slurm = True + # todo: specify the possible exception + except Exception: + pass + + if on_slurm: + log.info("Set SLURM handle signals.") + signal.signal(signal.SIGUSR1, self.sig_slurm_handler) signal.signal(signal.SIGTERM, self.term_handler) - def sig_handler(self, signum, frame): # pragma: no-cover + def 
sig_slurm_handler(self, signum, frame): # pragma: no-cover + if self.trainer.is_global_zero: + # save weights + log.info("handling SIGUSR1") + self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger) + + # find job id + job_id = os.environ["SLURM_JOB_ID"] + cmd = ["scontrol", "requeue", job_id] + + # requeue job + log.info(f"requeing job {job_id}...") + try: + result = call(cmd) + except FileNotFoundError: + # This can occur if a subprocess call to `scontrol` is run outside a shell context + # Re-attempt call (now with shell context). If any error is raised, propagate to user. + # When running a shell command, it should be passed as a single string. + joint_cmd = [str(x) for x in cmd] + result = call(" ".join(joint_cmd), shell=True) + + # print result text + if result == 0: + log.info(f"requeued exp {job_id}") + else: + log.warning("requeue failed...") + + # close experiment to avoid issues + self.trainer.logger.close() + + def sig_fault_tolerant_handler(self, signum, frame): # pragma: no-cover self.trainer._should_gracefully_terminate = True def term_handler(self, signum, frame): # pragma: no-cover diff --git a/pytorch_lightning/trainer/connectors/slurm_connector.py b/pytorch_lightning/trainer/connectors/slurm_connector.py deleted file mode 100644 index 053e1397ba2a2..0000000000000 --- a/pytorch_lightning/trainer/connectors/slurm_connector.py +++ /dev/null @@ -1,60 +0,0 @@ -import logging -import os -import signal -from subprocess import call - -log = logging.getLogger(__name__) - - -class SLURMConnector: - def __init__(self, trainer): - self.trainer = trainer - - def register_slurm_signal_handlers(self): - # see if we're using slurm (not interactive) - on_slurm = False - try: - job_name = os.environ["SLURM_JOB_NAME"] - if job_name != "bash": - on_slurm = True - # todo: specify the possible exception - except Exception: - pass - - if on_slurm: - log.info("Set SLURM handle signals.") - signal.signal(signal.SIGUSR1, self.sig_handler) - signal.signal(signal.SIGTERM, self.term_handler) - - def sig_handler(self, signum, frame): # pragma: no-cover - if self.trainer.is_global_zero: - # save weights - log.info("handling SIGUSR1") - self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger) - - # find job id - job_id = os.environ["SLURM_JOB_ID"] - cmd = ["scontrol", "requeue", job_id] - - # requeue job - log.info(f"requeing job {job_id}...") - try: - result = call(cmd) - except FileNotFoundError: - # This can occur if a subprocess call to `scontrol` is run outside a shell context - # Re-attempt call (now with shell context). If any error is raised, propagate to user. - # When running a shell command, it should be passed as a single string. 
- joint_cmd = [str(x) for x in cmd] - result = call(" ".join(joint_cmd), shell=True) - - # print result text - if result == 0: - log.info(f"requeued exp {job_id}") - else: - log.warning("requeue failed...") - - # close experiment to avoid issues - self.trainer.logger.close() - - def term_handler(self, signum, frame): # pragma: no-cover - log.info("bypassing sigterm") diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9de5fc1eca6a0..2a6302fe91a35 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -53,7 +53,6 @@ from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector -from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin from pytorch_lightning.trainer.deprecated_api import DeprecatedTrainerAttributes @@ -384,7 +383,6 @@ def __init__( self.debugging_connector = DebuggingConnector(self) self.training_tricks_connector = TrainingTricksConnector(self) self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint) - self.slurm_connector = SLURMConnector(self) self.signal_connector = SignalConnector(self) self.tuner = Tuner(self) @@ -1098,9 +1096,7 @@ def _pre_training_routine(self): # wait for all to join if on distributed self.accelerator.barrier("setup_training") - # register auto-resubmit when on SLURM - self.slurm_connector.register_slurm_signal_handlers() - + # register signals self.signal_connector.register_signal_handlers() self.checkpoint_connector.resume_end() From 6f26558c31b25a1936c8d756c97be438b3542e60 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 16 Sep 2021 16:40:20 +0100 Subject: [PATCH 05/19] update --- pytorch_lightning/trainer/connectors/signal_connector.py | 3 ++- tests/trainer/connectors/test_signal_connector.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 07703f0680c7a..feb88d8f8d506 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -1,6 +1,7 @@ import logging import os import signal +import sys from subprocess import call from pytorch_lightning.plugins.environments import SLURMEnvironment @@ -18,7 +19,7 @@ def register_signal_handlers(self): cluster_env = getattr(self.trainer.training_type_plugin, "cluster_environment", None) if isinstance(cluster_env, SLURMEnvironment): self.register_slurm_signal_handlers() - elif _fault_tolerant_training(): + elif _fault_tolerant_training() and not sys.platform == "win32": self.register_fault_tolerant_handlers() def register_fault_tolerant_handlers(self): diff --git a/tests/trainer/connectors/test_signal_connector.py b/tests/trainer/connectors/test_signal_connector.py index 9462881ebd497..17d0b9f9516af 100644 --- a/tests/trainer/connectors/test_signal_connector.py +++ b/tests/trainer/connectors/test_signal_connector.py @@ -24,7 +24,7 @@ @pytest.mark.parametrize("should_gracefully_terminate", [False, True]) -@RunIf(min_torch="1.7.0", special=True) +@RunIf(min_torch="1.7.0", skip_windows=True) def 
test_fault_tolerant_sig_handler(should_gracefully_terminate, tmpdir): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(should_gracefully_terminate))}): From bea0adc3739eaeff915cd09b01c9e180e571ae8c Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 17 Sep 2021 12:30:23 +0100 Subject: [PATCH 06/19] wip --- .../trainer/connectors/signal_connector.py | 70 ++++++++++++++----- 1 file changed, 52 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index feb88d8f8d506..a93c06f2ac8ad 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -1,32 +1,69 @@ import logging import os import signal -import sys from subprocess import call +from typing import Callable, List, Optional, Union -from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities.imports import _fault_tolerant_training log = logging.getLogger(__name__) +class HandlersCompose: + def __init__(self, signal_handlers: Union[List[Callable], Callable], user_defined_handler: Optional[Callable]): + if not isinstance(signal_handlers, list): + signal_handlers = [signal_handlers] + self.signal_handlers = signal_handlers + self.user_defined_handler = user_defined_handler + + def __call__(self): + if self.user_defined_handler: + self.user_defined_handler() + else: + for signal_handler in self.signal_handlers: + signal_handler() + + class SignalConnector: def __init__(self, trainer): self.trainer = trainer self.trainer._should_gracefully_terminate = False + self._sigusr1_handler: Optional[Callable] = None + self._sigterm_handler: Optional[Callable] = None + + @property + def sigusr1_handler(self) -> Optional[Callable]: + return self._sigusr1_handler + + @sigusr1_handler.setter + def sigusr1_handler(self, sigusr1_handler: Callable) -> None: + self._sigusr1_handler = sigusr1_handler + + @property + def sigterm_handler(self) -> Optional[Callable]: + return self._sigterm_handler + + @sigterm_handler.setter + def sigterm_handler(self, sigterm_handler: Callable) -> None: + self._sigterm_handler = sigterm_handler def register_signal_handlers(self): - cluster_env = getattr(self.trainer.training_type_plugin, "cluster_environment", None) - if isinstance(cluster_env, SLURMEnvironment): - self.register_slurm_signal_handlers() - elif _fault_tolerant_training() and not sys.platform == "win32": - self.register_fault_tolerant_handlers() + sigusr1_handlers = [] + sigterm_handlers = [] + + if _fault_tolerant_training(): + sigusr1_handlers.append(self.fault_tolerant_sigusr1_handler_fn) + + if self._is_on_slurm(): + log.info("Set SLURM handle signals.") + sigusr1_handlers.append(self.slurm_sigusr1_handler_fn) + + sigterm_handlers.append(self.sigterm_handler_fn) - def register_fault_tolerant_handlers(self): - signal.signal(signal.SIGUSR1, self.sig_fault_tolerant_handler) - signal.signal(signal.SIGTERM, self.term_handler) + signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers, self.sigusr1_handler)) + signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers, self.sigterm_handler)) - def register_slurm_signal_handlers(self): + def _is_on_slurm(self) -> bool: # see if we're using slurm (not interactive) on_slurm = False try: @@ -37,12 +74,9 @@ def register_slurm_signal_handlers(self): except Exception: pass - if on_slurm: - log.info("Set SLURM handle signals.") - signal.signal(signal.SIGUSR1, 
self.sig_slurm_handler) - signal.signal(signal.SIGTERM, self.term_handler) + return on_slurm - def sig_slurm_handler(self, signum, frame): # pragma: no-cover + def slurm_sigusr1_handler_fn(self, signum, frame): # pragma: no-cover if self.trainer.is_global_zero: # save weights log.info("handling SIGUSR1") @@ -72,8 +106,8 @@ def sig_slurm_handler(self, signum, frame): # pragma: no-cover # close experiment to avoid issues self.trainer.logger.close() - def sig_fault_tolerant_handler(self, signum, frame): # pragma: no-cover + def fault_tolerant_sigusr1_handler_fn(self, signum, frame): # pragma: no-cover self.trainer._should_gracefully_terminate = True - def term_handler(self, signum, frame): # pragma: no-cover + def sigterm_handler_fn(self, signum, frame): # pragma: no-cover log.info("bypassing sigterm") From d685bc3bcec0b93f449bb20c529694e9b38bf29a Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 17 Sep 2021 12:31:44 +0100 Subject: [PATCH 07/19] update --- pytorch_lightning/trainer/connectors/signal_connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index a93c06f2ac8ad..6d2e74cd80897 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -16,12 +16,12 @@ def __init__(self, signal_handlers: Union[List[Callable], Callable], user_define self.signal_handlers = signal_handlers self.user_defined_handler = user_defined_handler - def __call__(self): + def __call__(self, signum, frame): if self.user_defined_handler: - self.user_defined_handler() + self.user_defined_handler(signum, frame) else: for signal_handler in self.signal_handlers: - signal_handler() + signal_handler(signum, frame) class SignalConnector: From dd7002ff8b019e1ee002eabb6325bb5e6b2c39d3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 17 Sep 2021 12:32:55 +0100 Subject: [PATCH 08/19] update --- .../trainer/connectors/signal_connector.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 6d2e74cd80897..cb95082f9fd78 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -10,18 +10,14 @@ class HandlersCompose: - def __init__(self, signal_handlers: Union[List[Callable], Callable], user_defined_handler: Optional[Callable]): + def __init__(self, signal_handlers: Union[List[Callable], Callable]): if not isinstance(signal_handlers, list): signal_handlers = [signal_handlers] self.signal_handlers = signal_handlers - self.user_defined_handler = user_defined_handler def __call__(self, signum, frame): - if self.user_defined_handler: - self.user_defined_handler(signum, frame) - else: - for signal_handler in self.signal_handlers: - signal_handler(signum, frame) + for signal_handler in self.signal_handlers: + signal_handler(signum, frame) class SignalConnector: @@ -60,8 +56,8 @@ def register_signal_handlers(self): sigterm_handlers.append(self.sigterm_handler_fn) - signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers, self.sigusr1_handler)) - signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers, self.sigterm_handler)) + signal.signal(signal.SIGUSR1, HandlersCompose(self.sigusr1_handler or sigusr1_handlers)) + signal.signal(signal.SIGTERM, 
HandlersCompose(self.sigterm_handler or sigterm_handlers)) def _is_on_slurm(self) -> bool: # see if we're using slurm (not interactive) From 82d319297297b2d011d005eede7d86b695dcabe1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 17 Sep 2021 12:35:18 +0100 Subject: [PATCH 09/19] update --- .../trainer/connectors/signal_connector.py | 10 +++++----- tests/trainer/connectors/test_signal_connector.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index cb95082f9fd78..4031623a87b43 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -21,11 +21,11 @@ def __call__(self, signum, frame): class SignalConnector: - def __init__(self, trainer): + def __init__(self, trainer, sigusr1_handler: Optional[Callable] = None, sigterm_handler: Optional[Callable] = None): self.trainer = trainer - self.trainer._should_gracefully_terminate = False - self._sigusr1_handler: Optional[Callable] = None - self._sigterm_handler: Optional[Callable] = None + self.trainer._terminate_gracefully = False + self._sigusr1_handler = sigusr1_handler + self._sigterm_handler = sigterm_handler @property def sigusr1_handler(self) -> Optional[Callable]: @@ -103,7 +103,7 @@ def slurm_sigusr1_handler_fn(self, signum, frame): # pragma: no-cover self.trainer.logger.close() def fault_tolerant_sigusr1_handler_fn(self, signum, frame): # pragma: no-cover - self.trainer._should_gracefully_terminate = True + self.trainer._terminate_gracefully = True def sigterm_handler_fn(self, signum, frame): # pragma: no-cover log.info("bypassing sigterm") diff --git a/tests/trainer/connectors/test_signal_connector.py b/tests/trainer/connectors/test_signal_connector.py index 17d0b9f9516af..2d0db835416f5 100644 --- a/tests/trainer/connectors/test_signal_connector.py +++ b/tests/trainer/connectors/test_signal_connector.py @@ -23,15 +23,15 @@ from tests.helpers.runif import RunIf -@pytest.mark.parametrize("should_gracefully_terminate", [False, True]) +@pytest.mark.parametrize("terminate_gracefully", [False, True]) @RunIf(min_torch="1.7.0", skip_windows=True) -def test_fault_tolerant_sig_handler(should_gracefully_terminate, tmpdir): +def test_fault_tolerant_sig_handler(terminate_gracefully, tmpdir): - with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(should_gracefully_terminate))}): + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(terminate_gracefully))}): class TestModel(BoringModel): def training_step(self, batch, batch_idx): - if should_gracefully_terminate and self.trainer.current_epoch == 1 and batch_idx == 1: + if terminate_gracefully and self.trainer.current_epoch == 1 and batch_idx == 1: os.kill(os.getpid(), signal.SIGUSR1) sleep(0.1) return super().training_step(batch, batch_idx) @@ -39,4 +39,4 @@ def training_step(self, batch, batch_idx): model = TestModel() trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, limit_train_batches=2, limit_val_batches=2) trainer.fit(model) - assert trainer._should_gracefully_terminate == should_gracefully_terminate + assert trainer._terminate_gracefully == terminate_gracefully From dc622361d1de6fa1836a28d49a6d7235d5d08a6e Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 17 Sep 2021 13:08:21 +0100 Subject: [PATCH 10/19] update --- .../trainer/connectors/signal_connector.py | 16 ++++++++++++++-- .../connectors/test_signal_connector.py | 19 
+++++++++++++++---- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 4031623a87b43..3e8db713b5371 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -1,9 +1,12 @@ import logging import os import signal +from signal import Signals from subprocess import call from typing import Callable, List, Optional, Union +import _signal + from pytorch_lightning.utilities.imports import _fault_tolerant_training log = logging.getLogger(__name__) @@ -56,8 +59,11 @@ def register_signal_handlers(self): sigterm_handlers.append(self.sigterm_handler_fn) - signal.signal(signal.SIGUSR1, HandlersCompose(self.sigusr1_handler or sigusr1_handlers)) - signal.signal(signal.SIGTERM, HandlersCompose(self.sigterm_handler or sigterm_handlers)) + if not self._has_already_handler(signal.SIGUSR1): + signal.signal(signal.SIGUSR1, HandlersCompose(self.sigusr1_handler or sigusr1_handlers)) + + if not self._has_already_handler(signal.SIGTERM): + signal.signal(signal.SIGTERM, HandlersCompose(self.sigterm_handler or sigterm_handlers)) def _is_on_slurm(self) -> bool: # see if we're using slurm (not interactive) @@ -107,3 +113,9 @@ def fault_tolerant_sigusr1_handler_fn(self, signum, frame): # pragma: no-cover def sigterm_handler_fn(self, signum, frame): # pragma: no-cover log.info("bypassing sigterm") + + def _has_already_handler(self, signal: Signals) -> bool: + try: + return isinstance(_signal.getsignal(signal), Callable) + except AttributeError: + return False diff --git a/tests/trainer/connectors/test_signal_connector.py b/tests/trainer/connectors/test_signal_connector.py index 2d0db835416f5..1ec62f3d3082f 100644 --- a/tests/trainer/connectors/test_signal_connector.py +++ b/tests/trainer/connectors/test_signal_connector.py @@ -23,20 +23,31 @@ from tests.helpers.runif import RunIf +@pytest.mark.parametrize("register_handler", [False, True]) @pytest.mark.parametrize("terminate_gracefully", [False, True]) @RunIf(min_torch="1.7.0", skip_windows=True) -def test_fault_tolerant_sig_handler(terminate_gracefully, tmpdir): +def test_fault_tolerant_sig_handler(register_handler, terminate_gracefully, tmpdir): + + # hack to reset the signal + signal.signal(signal.SIGUSR1, 0) + + if register_handler: + + def handler(*_): + pass + + signal.signal(signal.SIGUSR1, handler) with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(terminate_gracefully))}): class TestModel(BoringModel): def training_step(self, batch, batch_idx): - if terminate_gracefully and self.trainer.current_epoch == 1 and batch_idx == 1: + if terminate_gracefully or register_handler: os.kill(os.getpid(), signal.SIGUSR1) sleep(0.1) return super().training_step(batch, batch_idx) model = TestModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, limit_train_batches=2, limit_val_batches=2) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0) trainer.fit(model) - assert trainer._terminate_gracefully == terminate_gracefully + assert trainer._terminate_gracefully == (False if register_handler else terminate_gracefully) From 5f677d43d378e4cf98976bbf058a6240f165098a Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 17 Sep 2021 16:10:11 +0100 Subject: [PATCH 11/19] update --- pytorch_lightning/trainer/connectors/signal_connector.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff 
--git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 3e8db713b5371..3603113b2ea38 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -5,8 +5,6 @@ from subprocess import call from typing import Callable, List, Optional, Union -import _signal - from pytorch_lightning.utilities.imports import _fault_tolerant_training log = logging.getLogger(__name__) @@ -114,8 +112,8 @@ def fault_tolerant_sigusr1_handler_fn(self, signum, frame): # pragma: no-cover def sigterm_handler_fn(self, signum, frame): # pragma: no-cover log.info("bypassing sigterm") - def _has_already_handler(self, signal: Signals) -> bool: + def _has_already_handler(self, sig: Signals) -> bool: try: - return isinstance(_signal.getsignal(signal), Callable) + return isinstance(signal.getsignal(sig), Callable) except AttributeError: return False From 3e522a32edf4d25a2a513192f11d14533e904f70 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 17 Sep 2021 16:16:19 +0100 Subject: [PATCH 12/19] update --- pyproject.toml | 1 + .../trainer/connectors/signal_connector.py | 46 ++++++------------- 2 files changed, 16 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d205af7f0a1f3..19110ea731996 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ module = [ "pytorch_lightning.loops.evaluation_loop", "pytorch_lightning.trainer.connectors.checkpoint_connector", "pytorch_lightning.trainer.connectors.logger_connector.*", + "pytorch_lightning.trainer.connectors.signal_connector", "pytorch_lightning.trainer.progress", "pytorch_lightning.tuner.auto_gpu_select", "pytorch_lightning.utilities.apply_func", diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 3603113b2ea38..54b15bff3b597 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -3,8 +3,10 @@ import signal from signal import Signals from subprocess import call -from typing import Callable, List, Optional, Union +from types import FrameType +from typing import Callable, List, Union +import pytorch_lightning as pl from pytorch_lightning.utilities.imports import _fault_tolerant_training log = logging.getLogger(__name__) @@ -22,31 +24,13 @@ def __call__(self, signum, frame): class SignalConnector: - def __init__(self, trainer, sigusr1_handler: Optional[Callable] = None, sigterm_handler: Optional[Callable] = None): + def __init__(self, trainer: "pl.Trainer"): self.trainer = trainer - self.trainer._terminate_gracefully = False - self._sigusr1_handler = sigusr1_handler - self._sigterm_handler = sigterm_handler + self.trainer._terminate_gracefully: bool = False - @property - def sigusr1_handler(self) -> Optional[Callable]: - return self._sigusr1_handler - - @sigusr1_handler.setter - def sigusr1_handler(self, sigusr1_handler: Callable) -> None: - self._sigusr1_handler = sigusr1_handler - - @property - def sigterm_handler(self) -> Optional[Callable]: - return self._sigterm_handler - - @sigterm_handler.setter - def sigterm_handler(self, sigterm_handler: Callable) -> None: - self._sigterm_handler = sigterm_handler - - def register_signal_handlers(self): - sigusr1_handlers = [] - sigterm_handlers = [] + def register_signal_handlers(self) -> None: + sigusr1_handlers: List[Callable] = [] + sigterm_handlers: List[Callable] = [] if 
_fault_tolerant_training(): sigusr1_handlers.append(self.fault_tolerant_sigusr1_handler_fn) @@ -58,10 +42,10 @@ def register_signal_handlers(self): sigterm_handlers.append(self.sigterm_handler_fn) if not self._has_already_handler(signal.SIGUSR1): - signal.signal(signal.SIGUSR1, HandlersCompose(self.sigusr1_handler or sigusr1_handlers)) + signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers)) if not self._has_already_handler(signal.SIGTERM): - signal.signal(signal.SIGTERM, HandlersCompose(self.sigterm_handler or sigterm_handlers)) + signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers)) def _is_on_slurm(self) -> bool: # see if we're using slurm (not interactive) @@ -76,7 +60,7 @@ def _is_on_slurm(self) -> bool: return on_slurm - def slurm_sigusr1_handler_fn(self, signum, frame): # pragma: no-cover + def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: if self.trainer.is_global_zero: # save weights log.info("handling SIGUSR1") @@ -106,14 +90,14 @@ def slurm_sigusr1_handler_fn(self, signum, frame): # pragma: no-cover # close experiment to avoid issues self.trainer.logger.close() - def fault_tolerant_sigusr1_handler_fn(self, signum, frame): # pragma: no-cover + def fault_tolerant_sigusr1_handler_fn(self, signum: Signals, frame: FrameType): self.trainer._terminate_gracefully = True - def sigterm_handler_fn(self, signum, frame): # pragma: no-cover + def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None: log.info("bypassing sigterm") - def _has_already_handler(self, sig: Signals) -> bool: + def _has_already_handler(self, signum: Signals) -> bool: try: - return isinstance(signal.getsignal(sig), Callable) + return isinstance(signal.getsignal(signum), Callable) except AttributeError: return False From 11b2b93528305d7e95408adf886f7b0fff1a5703 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 17 Sep 2021 16:24:37 +0100 Subject: [PATCH 13/19] update --- .../trainer/connectors/signal_connector.py | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 54b15bff3b597..65d3cea5ddb08 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -1,6 +1,7 @@ import logging import os import signal +import sys from signal import Signals from subprocess import call from types import FrameType @@ -41,24 +42,12 @@ def register_signal_handlers(self) -> None: sigterm_handlers.append(self.sigterm_handler_fn) - if not self._has_already_handler(signal.SIGUSR1): - signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers)) + if not self._is_on_windows(): + if not self._has_already_handler(signal.SIGUSR1): + signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers)) - if not self._has_already_handler(signal.SIGTERM): - signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers)) - - def _is_on_slurm(self) -> bool: - # see if we're using slurm (not interactive) - on_slurm = False - try: - job_name = os.environ["SLURM_JOB_NAME"] - if job_name != "bash": - on_slurm = True - # todo: specify the possible exception - except Exception: - pass - - return on_slurm + if not self._has_already_handler(signal.SIGTERM): + signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers)) def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: if self.trainer.is_global_zero: @@ -96,6 +85,22 @@ def 
fault_tolerant_sigusr1_handler_fn(self, signum: Signals, frame: FrameType): def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None: log.info("bypassing sigterm") + def _is_on_slurm(self) -> bool: + # see if we're using slurm (not interactive) + on_slurm = False + try: + job_name = os.environ["SLURM_JOB_NAME"] + if job_name != "bash": + on_slurm = True + # todo: specify the possible exception + except Exception: + pass + + return on_slurm + + def _is_on_windows(self) -> bool: + return sys.platform == "win32" + def _has_already_handler(self, signum: Signals) -> bool: try: return isinstance(signal.getsignal(signum), Callable) From a8a7ac5245cf63e7dbcdbf87ea3b2a1a9c2f50ae Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 17 Sep 2021 16:40:25 +0100 Subject: [PATCH 14/19] resolve mypy --- .../trainer/connectors/signal_connector.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 65d3cea5ddb08..fe24343e01028 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -4,7 +4,7 @@ import sys from signal import Signals from subprocess import call -from types import FrameType +from types import FrameType, FunctionType from typing import Callable, List, Union import pytorch_lightning as pl @@ -19,7 +19,7 @@ def __init__(self, signal_handlers: Union[List[Callable], Callable]): signal_handlers = [signal_handlers] self.signal_handlers = signal_handlers - def __call__(self, signum, frame): + def __call__(self, signum: Signals, frame: FrameType) -> None: for signal_handler in self.signal_handlers: signal_handler(signum, frame) @@ -27,7 +27,7 @@ def __call__(self, signum, frame): class SignalConnector: def __init__(self, trainer: "pl.Trainer"): self.trainer = trainer - self.trainer._terminate_gracefully: bool = False + self.trainer._terminate_gracefully = False def register_signal_handlers(self) -> None: sigusr1_handlers: List[Callable] = [] @@ -79,7 +79,7 @@ def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: # close experiment to avoid issues self.trainer.logger.close() - def fault_tolerant_sigusr1_handler_fn(self, signum: Signals, frame: FrameType): + def fault_tolerant_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: self.trainer._terminate_gracefully = True def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None: @@ -103,6 +103,6 @@ def _is_on_windows(self) -> bool: def _has_already_handler(self, signum: Signals) -> bool: try: - return isinstance(signal.getsignal(signum), Callable) + return isinstance(signal.getsignal(signum), FunctionType) except AttributeError: return False From 1af0fdd3da793eefdf0631396cc6c12cd77b5859 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 17 Sep 2021 17:32:24 +0100 Subject: [PATCH 15/19] Update CHANGELOG.md Co-authored-by: Sean Naren --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b1f6c483a5b2..9096aefe1a17d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,7 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950)) * Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401)) * Added support for restarting an optimizer loop (multiple optimizers) ([#9537](https://github.com/PyTorchLightning/pytorch-lightning/pull/9537)) - * Added a mechanism to detect a signal as been sent so the Trainer can gracefully exit ([#9566](https://github.com/PyTorchLightning/pytorch-lightning/pull/9566)) + * Added mechanism to detect a signal has been sent so the Trainer can gracefully exit ([#9566](https://github.com/PyTorchLightning/pytorch-lightning/pull/9566)) - Checkpoint saving & loading extensibility: From 5034bd998b00a0a83b909ac0aac4613828eea7ab Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 17 Sep 2021 17:33:24 +0100 Subject: [PATCH 16/19] update --- pytorch_lightning/trainer/connectors/signal_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index fe24343e01028..b46db8579de23 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -42,6 +42,7 @@ def register_signal_handlers(self) -> None: sigterm_handlers.append(self.sigterm_handler_fn) + # signal.SIGUSR1 doesn't seem available on windows if not self._is_on_windows(): if not self._has_already_handler(signal.SIGUSR1): signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers)) From 9fc7d0ac20483d859d2f38f3b1f5ab9fd7c48c13 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 17 Sep 2021 18:15:36 +0100 Subject: [PATCH 17/19] resolve mypy --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- pytorch_lightning/trainer/connectors/signal_connector.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 5376a81b658d6..516b6c03c0a27 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -274,7 +274,7 @@ def restore_lr_schedulers(self) -> None: # PRIVATE OPS # ---------------------------------- - def hpc_save(self, folderpath: str, logger: LightningLoggerBase) -> str: + def hpc_save(self, folderpath: str, logger: Optional[LightningLoggerBase]) -> str: # make sure the checkpoint folder exists folderpath = str(folderpath) # because the tests pass a path object fs = get_filesystem(folderpath) diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index b46db8579de23..2db3130c4c9f3 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -78,7 +78,8 @@ def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: log.warning("requeue failed...") # close experiment to avoid issues - self.trainer.logger.close() + if self.trainer.logger: + self.trainer.logger.finalize() def fault_tolerant_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: self.trainer._terminate_gracefully = True From c100d1e38aa381b8fecb0533f67075f050ed41ef Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 17 Sep 2021 18:54:38 
+0100 Subject: [PATCH 18/19] update --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- pytorch_lightning/trainer/connectors/signal_connector.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 516b6c03c0a27..5275499ebc79b 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -281,7 +281,7 @@ def hpc_save(self, folderpath: str, logger: Optional[LightningLoggerBase]) -> st fs.makedirs(folderpath, exist_ok=True) # save logger to make sure we get all the metrics - logger.save() + logger.finalize("finished") max_suffix = self.max_ckpt_version_in_folder(folderpath) ckpt_number = (max_suffix if max_suffix is not None else 0) + 1 diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 2db3130c4c9f3..8e21ffc6dd44c 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -79,7 +79,7 @@ def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: # close experiment to avoid issues if self.trainer.logger: - self.trainer.logger.finalize() + self.trainer.logger.finalize("finished") def fault_tolerant_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: self.trainer._terminate_gracefully = True From 7216a39e31ed20bfd4513e00b6d7cb2304194f0b Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 17 Sep 2021 18:58:46 +0100 Subject: [PATCH 19/19] update --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 5275499ebc79b..b750b0f81b26f 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -281,7 +281,8 @@ def hpc_save(self, folderpath: str, logger: Optional[LightningLoggerBase]) -> st fs.makedirs(folderpath, exist_ok=True) # save logger to make sure we get all the metrics - logger.finalize("finished") + if logger: + logger.finalize("finished") max_suffix = self.max_ckpt_version_in_folder(folderpath) ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
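
For reference, the behaviour this series converges on can be reproduced outside Lightning with nothing but the Python standard library. The sketch below is illustrative only: `_TrainerStub`, `_graceful_handler`, and `main` are hypothetical stand-ins rather than Lightning APIs. It mirrors the final design of the connector: several handlers are chained behind a single `signal.signal()` registration via `HandlersCompose`, a SIGUSR1 flips a `_terminate_gracefully` flag, and registration is skipped when a handler is already installed (as `_has_already_handler` does). SIGUSR1 does not exist on Windows, which is why the final connector also guards on `_is_on_windows()`.

import os
import signal
import time
from types import FrameType, FunctionType
from typing import Callable, List, Union


class HandlersCompose:
    """Chain several signal handlers behind a single signal.signal() registration."""

    def __init__(self, signal_handlers: Union[List[Callable], Callable]) -> None:
        if not isinstance(signal_handlers, list):
            signal_handlers = [signal_handlers]
        self.signal_handlers = signal_handlers

    def __call__(self, signum: int, frame: FrameType) -> None:
        for handler in self.signal_handlers:
            handler(signum, frame)


class _TrainerStub:
    """Hypothetical stand-in for the Trainer attribute set by the connector."""

    _terminate_gracefully = False


def _graceful_handler(trainer: _TrainerStub) -> Callable:
    def handler(signum: int, frame: FrameType) -> None:
        # like fault_tolerant_sigusr1_handler_fn: only record the request;
        # the training loop decides when to stop and checkpoint
        trainer._terminate_gracefully = True

    return handler


def _has_already_handler(signum: int) -> bool:
    # a user-installed Python handler is a plain function; the defaults
    # (signal.SIG_DFL / signal.SIG_IGN) are not
    return isinstance(signal.getsignal(signum), FunctionType)


def main() -> None:
    trainer = _TrainerStub()
    if not _has_already_handler(signal.SIGUSR1):  # POSIX only, unavailable on Windows
        signal.signal(signal.SIGUSR1, HandlersCompose([_graceful_handler(trainer)]))

    os.kill(os.getpid(), signal.SIGUSR1)  # same trick used by the test in this series
    time.sleep(0.1)  # give the interpreter a moment to run the Python-level handler
    assert trainer._terminate_gracefully
    print("graceful termination requested:", trainer._terminate_gracefully)


if __name__ == "__main__":
    main()

Running this on a POSIX system prints the flag as True, matching the assertion in test_fault_tolerant_sig_handler; in the real connector the SLURM requeue handler is appended to the same handler list when SLURM_JOB_NAME indicates a non-interactive batch job.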