
Commit a0ca2c8

awaelchli, carmocca, and pre-commit-ci[bot] authored
Disable memory sharing on model parameters in ddp-spawn (#18238)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0d1932c commit a0ca2c8

File tree

7 files changed (+124, -6 lines)

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -201,6 +201,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with `Fabric.all_reduce()` not performing an inplace operation for all backends consistently ([#18235](https://github.com/Lightning-AI/lightning/pull/18235))


+- Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn"` and `accelerator="cpu"`; this has a necessary memory impact, as parameters are replicated for each process now ([#18238](https://github.com/Lightning-AI/lightning/pull/18238))
+
+
 ## [2.0.5] - 2023-07-07

 ### Added
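For context on the changelog entry above: CPU tensors handed to child processes by torch.multiprocessing end up in shared memory, while cloning a tensor gives it private storage again. A minimal illustrative sketch using only public PyTorch APIs (not code from this commit):

import torch

t = torch.ones(3)
assert not t.is_shared()

# Sharing is what effectively happens to CPU tensors passed through `mp.spawn`.
t.share_memory_()
assert t.is_shared()

# Cloning moves the data into private storage; this is the essence of the fix,
# at the cost of one extra copy of the parameters per process.
private = t.clone()
assert not private.is_shared()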

src/lightning/fabric/strategies/launchers/multiprocessing.py

Lines changed: 26 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import os
 from dataclasses import dataclass
 from multiprocessing.queues import SimpleQueue
@@ -19,7 +20,10 @@
 import torch
 import torch.backends.cudnn
 import torch.multiprocessing as mp
+from lightning_utilities import apply_to_collection
+from torch.nn import Module

+from lightning.fabric.accelerators.cpu import CPUAccelerator
 from lightning.fabric.strategies.launchers.launcher import _Launcher
 from lightning.fabric.utilities.apply_func import move_data_to_device
 from lightning.fabric.utilities.imports import _IS_INTERACTIVE
@@ -122,6 +126,10 @@ def _wrapping_function(
     ) -> None:
         if global_states:
             global_states.restore()
+
+        if self._start_method == "spawn" and isinstance(self._strategy.accelerator, CPUAccelerator):
+            args, kwargs = _disable_module_memory_sharing((args, kwargs))
+
         os.environ["LOCAL_RANK"] = str(process_idx)
         results = function(*args, **kwargs)
@@ -190,3 +198,21 @@ def _check_bad_cuda_fork() -> None:
     if _IS_INTERACTIVE:
         message += " You will have to restart the Python kernel."
     raise RuntimeError(message)
+
+
+def _disable_module_memory_sharing(data: Any) -> Any:
+    """Disables memory sharing on parameters and buffers of `nn.Module`s contained in the given collection.
+
+    Note: This is only required when running on CPU.
+    """
+    # PyTorch enables memory sharing automatically on all tensors that are passed through `mp.spawn`.
+    # For model weights and buffers, this is undesired and can lead to race conditions between processes.
+    # Hence, we copy the tensors in the entire module to ensure it doesn't share memory with other processes.
+
+    @torch.no_grad()
+    def unshare(module: Module) -> Module:
+        for tensor in itertools.chain(module.parameters(), module.buffers()):
+            tensor.data = tensor.data.clone()
+        return module
+
+    return apply_to_collection(data, function=unshare, dtype=Module)
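To see what the new private helper does in isolation, here is a minimal single-process sketch (run outside the actual spawn path; `Module.share_memory()` is used only to simulate the shared-memory state a spawned worker would otherwise receive):

import torch
from torch.nn import Linear

from lightning.fabric.strategies.launchers.multiprocessing import _disable_module_memory_sharing

model = Linear(2, 2)
model.share_memory()  # simulate what mp.spawn does to CPU tensors
assert model.weight.is_shared()

# The helper clones every parameter and buffer of any nn.Module found in the
# (args, kwargs) collection, so the module no longer shares storage across processes.
(model,), _ = _disable_module_memory_sharing(((model,), {}))
assert not model.weight.is_shared()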

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -218,6 +218,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue that could cause the `LightningOptimizer` wrapper returned by `LightningModule.optimizers()` have different internal state than the optimizer it wraps ([#18280](https://github.com/Lightning-AI/lightning/pull/18280))


+- Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn"` and `accelerator="cpu"`; this has a necessary memory impact, as parameters are replicated for each process now ([#18238](https://github.com/Lightning-AI/lightning/pull/18238))
+

 ## [2.0.5] - 2023-07-07


src/lightning/pytorch/strategies/launchers/multiprocessing.py

Lines changed: 5 additions & 1 deletion
@@ -27,10 +27,11 @@
 from torch import Tensor

 import lightning.pytorch as pl
-from lightning.fabric.strategies.launchers.multiprocessing import _check_bad_cuda_fork
+from lightning.fabric.strategies.launchers.multiprocessing import _check_bad_cuda_fork, _disable_module_memory_sharing
 from lightning.fabric.utilities import move_data_to_device
 from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
 from lightning.fabric.utilities.types import _PATH
+from lightning.pytorch.accelerators import CPUAccelerator
 from lightning.pytorch.strategies.launchers.launcher import _Launcher
 from lightning.pytorch.trainer.connectors.signal_connector import _SIGNUM
 from lightning.pytorch.trainer.states import TrainerFn, TrainerState
@@ -144,6 +145,9 @@ def _wrapping_function(
     ) -> None:
         if global_states:
             global_states.restore()
+        if self._start_method == "spawn" and isinstance(self._strategy.accelerator, CPUAccelerator):
+            args, kwargs = _disable_module_memory_sharing((args, kwargs))
+
         os.environ["LOCAL_RANK"] = str(process_idx)
         results = function(*args, **kwargs)
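The Trainer takes this branch whenever it runs multi-process on CPU with the spawn start method. A minimal usage sketch (assuming Lightning's `BoringModel` demo module as a stand-in for a real model):

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

# accelerator="cpu" + strategy="ddp_spawn" selects the "spawn" start method,
# so each worker clones the model's parameters and buffers before training starts.
trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn", max_epochs=1)
trainer.fit(BoringModel())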

tests/tests_fabric/strategies/launchers/test_multiprocessing.py

Lines changed: 5 additions & 5 deletions
@@ -23,20 +23,20 @@

 @RunIf(skip_windows=True)
 @pytest.mark.parametrize("start_method", ["fork", "forkserver"])
-def test_multiprocessing_launcher_interactive_compatible(start_method):
+def test_interactive_compatible(start_method):
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
     assert launcher.is_interactive_compatible == (start_method == "fork")


 @mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
-def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
+def test_forking_on_unsupported_platform(_):
     with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
         _MultiProcessingLauncher(strategy=Mock(), start_method="fork")


 @pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
 @mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
-def test_multiprocessing_launcher_start_method(mp_mock, start_method):
+def test_start_method(mp_mock, start_method):
     mp_mock.get_all_start_methods.return_value = [start_method]
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
     launcher.launch(function=Mock())
@@ -51,7 +51,7 @@ def test_multiprocessing_launcher_start_method(mp_mock, start_method):

 @pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
 @mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
-def test_multiprocessing_launcher_restore_globals(mp_mock, start_method):
+def test_restore_globals(mp_mock, start_method):
     """Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
     mp_mock.get_all_start_methods.return_value = [start_method]
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
@@ -89,7 +89,7 @@ def test_global_state_snapshot():
 @pytest.mark.parametrize("start_method", ["fork", "forkserver"])
 @mock.patch("torch.cuda.is_initialized", return_value=True)
 @mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
-def test_multiprocessing_launcher_check_for_bad_cuda_fork(mp_mock, _, start_method):
+def test_check_for_bad_cuda_fork(mp_mock, _, start_method):
     mp_mock.get_all_start_methods.return_value = [start_method]
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
     with pytest.raises(RuntimeError, match="Lightning can't create new processes if CUDA is already initialized"):
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+import torch.nn as nn
+
+from lightning.fabric import Fabric
+from tests_fabric.helpers.runif import RunIf
+
+
+class SimpleModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer = nn.Linear(2, 2)
+        self.tied_layer = nn.Linear(2, 2)
+        self.tied_layer.weight = self.layer.weight
+        self.register_buffer("buffer", torch.ones(3))
+
+
+@pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))])
+def test_memory_sharing_disabled(strategy):
+    """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
+    conditions on model updates."""
+    tensor = torch.rand(4)
+    model = SimpleModel()
+    assert not tensor.is_shared()
+    assert not model.layer.weight.is_shared()
+    assert model.layer.weight.data_ptr() == model.tied_layer.weight.data_ptr()
+
+    fabric = Fabric(accelerator="cpu", devices=2, strategy=strategy)
+    fabric.launch(_test_memory_sharing_disabled, tensor, model)
+
+
+def _test_memory_sharing_disabled(fabric, tensor, model):
+    is_spawn = fabric.strategy.launcher._start_method == "spawn"
+    assert not is_spawn or tensor.is_shared()
+    assert not model.layer.weight.is_shared()
+    assert not model.tied_layer.weight.is_shared()
+    assert not model.buffer.is_shared()
+
+    # weights remain tied
+    assert model.layer.weight.data_ptr() == model.tied_layer.weight.data_ptr()
+    assert torch.equal(model.layer.weight.data, model.tied_layer.weight.data)
+    fabric.barrier()

tests/tests_pytorch/strategies/launchers/test_multiprocessing.py

Lines changed: 28 additions & 0 deletions
@@ -175,3 +175,31 @@ def test_kill():
     with patch("os.kill") as kill_patch:
         launcher.kill(15)
     assert kill_patch.mock_calls == [call(proc0.pid, 15), call(proc1.pid, 15)]
+
+
+class SimpleModel(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.tied_layer = torch.nn.Linear(32, 2)
+        self.tied_layer.weight = self.layer.weight
+        self.register_buffer("buffer", torch.ones(3))
+
+    def on_fit_start(self) -> None:
+        assert not self.layer.weight.is_shared()
+        assert not self.tied_layer.weight.is_shared()
+        assert not self.buffer.is_shared()
+
+        # weights remain tied
+        assert self.layer.weight.data_ptr() == self.tied_layer.weight.data_ptr()
+        assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data)
+
+
+def test_memory_sharing_disabled():
+    """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
+    conditions on model updates."""
+    model = SimpleModel()
+    assert not model.layer.weight.is_shared()
+    assert model.layer.weight.data_ptr() == model.tied_layer.weight.data_ptr()
+
+    trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn", max_steps=0)
+    trainer.fit(model)