Skip to content

Commit 73391c6

Browse files
authored
Merge branch 'master' into feature/simplify_store
2 parents 8509926 + 7e13eb7 commit 73391c6

File tree

7 files changed

+135
-45
lines changed

7 files changed

+135
-45
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
112112
- Allowed accessing rank information in the main process before processes are launched when using the `XLAStrategy` ([#18194](https://github.com/Lightning-AI/lightning/pull/18194))
113113

114114

115+
- Added automatic process cleanup to avoid zombie child processes and stalls when exceptions are raised ([#18218](https://github.com/Lightning-AI/lightning/pull/18218))
116+
117+
115118
### Changed
116119

117120
- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))

src/lightning/fabric/strategies/launchers/base.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

src/lightning/fabric/strategies/launchers/subprocess_script.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,21 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import logging
1415
import os
16+
import signal
1517
import subprocess
1618
import sys
17-
from typing import Any, Callable, Optional, Sequence, Tuple
19+
import time
20+
from threading import Thread
21+
from typing import Any, Callable, List, Optional, Sequence, Tuple
1822

1923
from lightning_utilities.core.imports import RequirementCache
2024

2125
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
2226
from lightning.fabric.strategies.launchers.launcher import _Launcher
2327

28+
_logger = logging.getLogger(__name__)
2429
_HYDRA_AVAILABLE = RequirementCache("hydra-core")
2530

2631

@@ -71,6 +76,7 @@ def __init__(
7176
self.cluster_environment = cluster_environment
7277
self.num_processes = num_processes
7378
self.num_nodes = num_nodes
79+
self.procs: List[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher
7480

7581
@property
7682
def is_interactive_compatible(self) -> bool:
@@ -87,6 +93,7 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
8793
"""
8894
if not self.cluster_environment.creates_processes_externally:
8995
self._call_children_scripts()
96+
_launch_process_observer(self.procs)
9097
return function(*args, **kwargs)
9198

9299
def _call_children_scripts(self) -> None:
@@ -122,9 +129,13 @@ def _call_children_scripts(self) -> None:
122129
command, cwd = _hydra_subprocess_cmd(local_rank=local_rank)
123130
else:
124131
command = _basic_subprocess_cmd()
125-
subprocess.Popen(command, env=env_copy, cwd=cwd)
132+
133+
proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
134+
self.procs.append(proc)
126135

127136
def _check_can_spawn_children(self) -> None:
137+
if len(self.procs) > 0:
138+
raise RuntimeError("The launcher can only create subprocesses once.")
128139
if self.cluster_environment.local_rank() != 0:
129140
raise RuntimeError(
130141
"Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen."
@@ -159,3 +170,53 @@ def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
159170
# Set output_subdir null since we don't want different subprocesses trying to write to config.yaml
160171
command += [f"hydra.run.dir={rundir}", f"hydra.job.name=train_ddp_process_{local_rank}", "hydra.output_subdir=null"]
161172
return command, cwd
173+
174+
175+
def _launch_process_observer(child_processes: List[subprocess.Popen]) -> None:
176+
"""Launches a thread that runs along the main process and monitors the health of all processes."""
177+
monitor_thread = Thread(
178+
target=_ChildProcessObserver(child_processes=child_processes, main_pid=os.getpid()),
179+
daemon=True, # thread stops if the main process exits
180+
)
181+
monitor_thread.start()
182+
183+
184+
class _ChildProcessObserver:
185+
def __init__(self, main_pid: int, child_processes: List[subprocess.Popen], sleep_period: int = 5) -> None:
186+
self._main_pid = main_pid
187+
self._child_processes = child_processes
188+
self._sleep_period = sleep_period
189+
# Note: SIGTERM is not aggressive enough to terminate processes hanging in collectives
190+
self._termination_signal = signal.SIGTERM if sys.platform == "win32" else signal.SIGKILL
191+
self._finished = False
192+
193+
def __call__(self) -> None:
194+
while not self._finished:
195+
time.sleep(self._sleep_period)
196+
self._finished = self._run()
197+
198+
def _run(self) -> bool:
199+
"""Runs once over all child processes to check whether they are still running."""
200+
for proc in self._child_processes:
201+
proc.poll()
202+
203+
return_codes = [proc.returncode for proc in self._child_processes]
204+
if all(return_code == 0 for return_code in return_codes):
205+
return True
206+
207+
for proc in self._child_processes:
208+
if proc.returncode:
209+
_logger.info(
210+
f"Child process with PID {proc.pid} terminated with code {proc.returncode}."
211+
f" Forcefully terminating all other processes to avoid zombies 🧟"
212+
)
213+
self._terminate_all()
214+
return True
215+
216+
return False
217+
218+
def _terminate_all(self) -> None:
219+
"""Terminates the main process and all its children."""
220+
for p in self._child_processes:
221+
p.send_signal(self._termination_signal)
222+
os.kill(self._main_pid, self._termination_signal)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
9494
- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193))
9595

9696

97+
- Added automatic process cleanup to avoid zombie child processes and stalls when exceptions are raised ([#18218](https://github.com/Lightning-AI/lightning/pull/18218))
98+
99+
97100
### Changed
98101

99102
- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))

src/lightning/pytorch/strategies/launchers/subprocess_script.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020

2121
import lightning.pytorch as pl
2222
from lightning.fabric.plugins import ClusterEnvironment
23-
from lightning.fabric.strategies.launchers.subprocess_script import _basic_subprocess_cmd, _hydra_subprocess_cmd
23+
from lightning.fabric.strategies.launchers.subprocess_script import (
24+
_basic_subprocess_cmd,
25+
_hydra_subprocess_cmd,
26+
_launch_process_observer,
27+
)
2428
from lightning.pytorch.strategies.launchers.launcher import _Launcher
2529
from lightning.pytorch.trainer.connectors.signal_connector import _SIGNUM
2630

@@ -70,7 +74,7 @@ def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int,
7074
self.cluster_environment = cluster_environment
7175
self.num_processes = num_processes
7276
self.num_nodes = num_nodes
73-
self.procs: List[subprocess.Popen] = [] # launched subprocesses. does not include the launcher
77+
self.procs: List[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher
7478

7579
@property
7680
def is_interactive_compatible(self) -> bool:
@@ -88,6 +92,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
8892
"""
8993
if not self.cluster_environment.creates_processes_externally:
9094
self._call_children_scripts()
95+
_launch_process_observer(self.procs)
9196
return function(*args, **kwargs)
9297

9398
def kill(self, signum: _SIGNUM) -> None:
@@ -134,6 +139,8 @@ def _call_children_scripts(self) -> None:
134139
self.procs.append(new_process)
135140

136141
def _check_can_spawn_children(self) -> None:
142+
if len(self.procs) > 0:
143+
raise RuntimeError("The launcher can only create subprocesses once.")
137144
if self.cluster_environment.local_rank() != 0:
138145
raise RuntimeError(
139146
"Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen."

tests/tests_fabric/strategies/launchers/test_subprocess_script.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
import signal
16+
import sys
1517
from unittest import mock
1618
from unittest.mock import ANY, Mock
1719

1820
import pytest
1921

2022
import lightning.fabric
21-
from lightning.fabric.strategies.launchers.subprocess_script import _HYDRA_AVAILABLE, _SubprocessScriptLauncher
23+
from lightning.fabric.strategies.launchers.subprocess_script import (
24+
_ChildProcessObserver,
25+
_HYDRA_AVAILABLE,
26+
_SubprocessScriptLauncher,
27+
)
2228

2329

2430
def test_subprocess_script_launcher_interactive_compatible():
@@ -27,17 +33,24 @@ def test_subprocess_script_launcher_interactive_compatible():
2733

2834

2935
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
30-
def test_subprocess_script_launcher_error_launching_on_non_zero_rank(popen_mock):
36+
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.Thread")
37+
def test_subprocess_script_launcher_can_launch(*_):
3138
cluster_env = Mock()
3239
cluster_env.creates_processes_externally = False
3340
cluster_env.local_rank.return_value = 1
3441
launcher = _SubprocessScriptLauncher(cluster_env, num_processes=2, num_nodes=1)
42+
3543
with pytest.raises(RuntimeError, match="attempted to launch new distributed processes with `local_rank > 0`"):
3644
launcher.launch(Mock())
3745

46+
launcher.procs = [Mock()] # there are already processes running
47+
with pytest.raises(RuntimeError, match="The launcher can only create subprocesses once"):
48+
launcher.launch(Mock())
49+
3850

3951
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
40-
def test_subprocess_script_launcher_external_processes(popen_mock):
52+
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.Thread")
53+
def test_subprocess_script_launcher_external_processes(_, popen_mock):
4154
cluster_env = Mock()
4255
cluster_env.creates_processes_externally = True
4356
function = Mock()
@@ -48,7 +61,8 @@ def test_subprocess_script_launcher_external_processes(popen_mock):
4861

4962

5063
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
51-
def test_subprocess_script_launcher_launch_processes(popen_mock):
64+
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.Thread")
65+
def test_subprocess_script_launcher_launch_processes(_, popen_mock):
5266
cluster_env = Mock()
5367
cluster_env.creates_processes_externally = False
5468
cluster_env.local_rank.return_value = 0
@@ -80,7 +94,8 @@ def test_subprocess_script_launcher_launch_processes(popen_mock):
8094

8195
@pytest.mark.skipif(not _HYDRA_AVAILABLE, reason="hydra-core is required")
8296
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.subprocess.Popen")
83-
def test_subprocess_script_launcher_hydra_in_use(popen_mock, monkeypatch):
97+
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.Thread")
98+
def test_subprocess_script_launcher_hydra_in_use(_, popen_mock, monkeypatch):
8499
basic_command = Mock(return_value="basic_command")
85100
hydra_command = Mock(return_value=("hydra_command", "hydra_cwd"))
86101
monkeypatch.setattr(lightning.fabric.strategies.launchers.subprocess_script, "_basic_subprocess_cmd", basic_command)
@@ -121,3 +136,37 @@ def simulate_launch():
121136
simulate_launch()
122137
popen_mock.assert_called_with("hydra_command", env=ANY, cwd="hydra_cwd")
123138
popen_mock.reset_mock()
139+
140+
141+
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.os.kill")
142+
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.time.sleep")
143+
def test_child_process_observer(sleep_mock, os_kill_mock):
144+
# Case 1: All processes are running and did not exit yet
145+
processes = [Mock(returncode=None), Mock(returncode=None)]
146+
observer = _ChildProcessObserver(main_pid=1234, child_processes=processes)
147+
finished = observer._run() # call _run() directly to simulate while loop
148+
assert not finished
149+
150+
# Case 2: All processes have finished with exit code 0 (success)
151+
processes = [Mock(returncode=0), Mock(returncode=0)]
152+
observer = _ChildProcessObserver(main_pid=1234, child_processes=processes)
153+
finished = observer._run() # call _run() directly to simulate while loop
154+
assert finished
155+
156+
# Case 3: One process has finished with exit code 1 (failure)
157+
processes = [Mock(returncode=0), Mock(returncode=1)]
158+
observer = _ChildProcessObserver(main_pid=1234, child_processes=processes)
159+
finished = observer._run() # call _run() directly to simulate while loop
160+
assert finished
161+
expected_signal = signal.SIGTERM if sys.platform == "win32" else signal.SIGKILL
162+
processes[0].send_signal.assert_called_once_with(expected_signal)
163+
processes[1].send_signal.assert_called_once_with(expected_signal)
164+
os_kill_mock.assert_called_once_with(1234, expected_signal)
165+
166+
# The main routine stops
167+
observer = _ChildProcessObserver(main_pid=1234, child_processes=[Mock(), Mock()])
168+
observer._run = Mock()
169+
assert not observer._finished
170+
observer()
171+
assert observer._finished
172+
sleep_mock.assert_called_once_with(5)

tests/tests_pytorch/strategies/launchers/test_subprocess_script.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import subprocess
22
import sys
3+
from unittest import mock
34
from unittest.mock import Mock
45

56
import pytest
@@ -77,7 +78,8 @@ def test_ddp_with_hydra_runjob(subdir, tmp_path, monkeypatch):
7778
assert len(logs) == devices
7879

7980

80-
def test_kill():
81+
@mock.patch("lightning.fabric.strategies.launchers.subprocess_script.Thread")
82+
def test_kill(_):
8183
launcher = _SubprocessScriptLauncher(Mock(), 1, 1)
8284
proc0 = Mock(autospec=subprocess.Popen)
8385
proc1 = Mock(autospec=subprocess.Popen)

0 commit comments

Comments
 (0)