
Commit d4d1970

awaelchli, carmocca, and kaushikb11 authored
Add SyncBatchNormPlugin (#11754)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
1 parent f3c2a13 commit d4d1970

File tree

12 files changed: +196, -89 lines

CHANGELOG.md
Lines changed: 10 additions & 0 deletions

@@ -140,6 +140,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for pluggable Accelerators ([#12030](https://github.com/PyTorchLightning/pytorch-lightning/pull/12030))
 
 
+- Added `LayerSync` and `NativeSyncBatchNorm` plugins ([#11754](https://github.com/PyTorchLightning/pytorch-lightning/pull/11754))
+
+
+
 ### Changed
 
 - Make `benchmark` flag optional and set its value based on the deterministic flag ([#11944](https://github.com/PyTorchLightning/pytorch-lightning/pull/11944))
@@ -629,6 +633,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
 
+- Removed `configure_sync_batchnorm` from `ParallelStrategy` and all other strategies that inherit from it ([#11754](https://github.com/PyTorchLightning/pytorch-lightning/pull/11754))
+
+
+- Removed public attribute `sync_batchnorm` from strategies ([#11754](https://github.com/PyTorchLightning/pytorch-lightning/pull/11754))
+
+
 ### Fixed
 
 - Fixed an issue where `ModelCheckpoint` could delete older checkpoints when `dirpath` has changed during resumed training ([#12045](https://github.com/PyTorchLightning/pytorch-lightning/pull/12045))

docs/source/api_references.rst
Lines changed: 15 additions & 0 deletions

@@ -237,6 +237,21 @@ Checkpoint IO Plugins
     TorchCheckpointIO
     XLACheckpointIO
 
+
+Other Plugins
+^^^^^^^^^^^^^
+
+.. currentmodule:: pytorch_lightning.plugins
+
+.. autosummary::
+    :toctree: api
+    :nosignatures:
+    :template: classtemplate.rst
+
+    LayerSync
+    NativeSyncBatchNorm
+
+
 Profiler API
 ------------
 

pytorch_lightning/plugins/__init__.py
Lines changed: 4 additions & 1 deletion

@@ -4,6 +4,7 @@
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
 from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
+from pytorch_lightning.plugins.layer_sync import LayerSync, NativeSyncBatchNorm
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
@@ -31,7 +32,7 @@
 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
 from pytorch_lightning.strategies import Strategy
 
-PLUGIN = Union[Strategy, PrecisionPlugin, ClusterEnvironment, CheckpointIO]
+PLUGIN = Union[Strategy, PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync]
 PLUGIN_INPUT = Union[PLUGIN, str]
 
 __all__ = [
@@ -63,4 +64,6 @@
     "ParallelPlugin",
     "DDPShardedPlugin",
     "DDPSpawnShardedPlugin",
+    "LayerSync",
+    "NativeSyncBatchNorm",
 ]
pytorch_lightning/plugins/layer_sync.py (new file, inferred from the imports above)
Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+
+import torch
+from torch.nn import Module
+
+
+class LayerSync(ABC):
+    """Abstract base class for creating plugins that wrap layers of a model with synchronization logic for
+    multiprocessing."""
+
+    @abstractmethod
+    def apply(self, model: Module) -> Module:
+        """Override this method to apply synchronization to the layers of this model."""
+
+    @abstractmethod
+    def revert(self, model: Module) -> Module:
+        """Override this method to undo all modifications made in :meth:`apply`."""
+
+
+class NativeSyncBatchNorm(LayerSync):
+    """A plugin that wraps all batch normalization layers of a model with synchronization logic for
+    multiprocessing.
+
+    This plugin has no effect in single-device operation.
+    """
+
+    def apply(self, model: Module) -> Module:
+        """Add global batchnorm for a model spread across multiple GPUs and nodes.
+
+        Override this method to synchronize batchnorm layers between specific process groups instead
+        of the whole world.
+
+        Args:
+            model: Reference to the current LightningModule
+
+        Return:
+            LightningModule with batchnorm layers synchronized within the process groups.
+        """
+        return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+    def revert(self, model: Module) -> Module:
+        """Convert the wrapped batchnorm layers back to regular batchnorm layers.
+
+        Args:
+            model: Reference to the current LightningModule
+
+        Return:
+            LightningModule with regular batchnorm layers that will no longer sync across processes.
+        """
+        # Code adapted from https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547
+        # Original author: Kapil Yedidi (@kapily)
+        converted_module = model
+        if isinstance(model, torch.nn.modules.batchnorm.SyncBatchNorm):
+            # Unfortunately, LayerSync does not store the original class - if it did
+            # we could return the one that was originally created.
+            converted_module = _BatchNormXd(
+                model.num_features, model.eps, model.momentum, model.affine, model.track_running_stats
+            )
+            if model.affine:
+                with torch.no_grad():
+                    converted_module.weight = model.weight
+                    converted_module.bias = model.bias
+            converted_module.running_mean = model.running_mean
+            converted_module.running_var = model.running_var
+            converted_module.num_batches_tracked = model.num_batches_tracked
+            if hasattr(model, "qconfig"):
+                converted_module.qconfig = model.qconfig
+        for name, child in model.named_children():
+            converted_module.add_module(name, self.revert(child))
+        del model
+        return converted_module
+
+
+class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
+    def _check_input_dim(self, input: torch.Tensor) -> None:
+        # The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
+        # is this method that is overwritten by the subclass.
+        # Here, we are bypassing some tensor sanity checks and trusting that the user
+        # provides the right input dimensions at inference.
+        return
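
For context beyond the diff: a minimal usage sketch of the plugin added above, assuming a multi-GPU DDP run. Per the accelerator connector changes further down, `Trainer(sync_batchnorm=True)` now builds a `NativeSyncBatchNorm` internally, and the plugin can equally be passed explicitly.

    # Hedged sketch: enabling synchronized batchnorm through the new plugin.
    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import NativeSyncBatchNorm

    # Option 1: the existing Trainer flag, now backed by the plugin under the hood.
    trainer = Trainer(gpus=2, strategy="ddp", sync_batchnorm=True)

    # Option 2: pass the plugin explicitly (equivalent to the flag).
    trainer = Trainer(gpus=2, strategy="ddp", plugins=[NativeSyncBatchNorm()])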

pytorch_lightning/strategies/ddp.py
Lines changed: 5 additions & 6 deletions

@@ -42,7 +42,7 @@
     _TORCH_GREATER_EQUAL_1_9,
     _TORCH_GREATER_EQUAL_1_10,
 )
-from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
+from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as _group
 from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
@@ -86,7 +86,6 @@ def __init__(
         )
         log.detail(f"{self.__class__.__name__}: initializing DDP plugin")
         self._num_nodes = 1
-        self.sync_batchnorm = False
         self._ddp_kwargs = kwargs
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
@@ -145,8 +144,8 @@ def setup(self, trainer: "pl.Trainer") -> None:
         # move the model to the correct device
         self.model_to_device()
 
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
+        if self._layer_sync:
+            self.model = self._layer_sync.apply(self.model)
 
         # skip wrapping the model if we are not fitting as no gradients need to be exchanged
         trainer_fn = trainer.state.fn
@@ -422,8 +421,8 @@ def teardown(self) -> None:
         if isinstance(self.model, DistributedDataParallel):
             self.model = self.lightning_module
 
-        if self.sync_batchnorm:
-            self.model = _revert_sync_batchnorm(self.model)
+        if self._layer_sync:
+            self.model = self._layer_sync.revert(self.model)
 
         if self.root_device.type == "cuda":
             # GPU teardown
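
For orientation: the lifecycle these strategy changes establish. The connector injects `_layer_sync` into the strategy, `apply()` runs in `setup()` before the model is wrapped in DDP, and `revert()` runs in `teardown()`. A rough standalone sketch of that round trip on a plain torch model:

    # Hedged sketch of the apply/revert round trip outside of any Trainer.
    import torch
    from pytorch_lightning.plugins import NativeSyncBatchNorm

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8), torch.nn.ReLU())
    layer_sync = NativeSyncBatchNorm()

    model = layer_sync.apply(model)   # BatchNorm2d layers become SyncBatchNorm (what setup() does)
    # ... distributed training would happen here ...
    model = layer_sync.revert(model)  # converted back to regular batchnorm (what teardown() does)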

pytorch_lightning/strategies/ddp_spawn.py
Lines changed: 5 additions & 6 deletions

@@ -30,7 +30,7 @@
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
-from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
+from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as _group
 from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
@@ -69,7 +69,6 @@ def __init__(
             precision_plugin=precision_plugin,
         )
         self._num_nodes = 1
-        self.sync_batchnorm = False
         self._ddp_kwargs = kwargs
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
@@ -116,8 +115,8 @@ def setup(self, trainer: "pl.Trainer") -> None:
         # move the model to the correct device
         self.model_to_device()
 
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
+        if self._layer_sync:
+            self.model = self._layer_sync.apply(self.model)
 
         # skip wrapping the model if we are not fitting as no gradients need to be exchanged
         trainer_fn = self.lightning_module.trainer.state.fn
@@ -269,8 +268,8 @@ def teardown(self) -> None:
         if isinstance(self.model, DistributedDataParallel):
             self.model = self.lightning_module
 
-        if self.sync_batchnorm:
-            self.model = _revert_sync_batchnorm(self.model)
+        if self._layer_sync:
+            self.model = self._layer_sync.revert(self.model)
 
         if self.root_device.type == "cuda":
             # GPU teardown

pytorch_lightning/strategies/fully_sharded.py
Lines changed: 2 additions & 2 deletions

@@ -139,8 +139,8 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_precision_plugin()
         optimizers_to_device(self.optimizers, self.root_device)
 
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
+        if self._layer_sync:
+            self.model = self._layer_sync.apply(self.model)
 
         self.configure_ddp()
         self.barrier()

pytorch_lightning/strategies/parallel.py
Lines changed: 2 additions & 15 deletions

@@ -21,6 +21,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import unwrap_lightning_module
+from pytorch_lightning.plugins import LayerSync
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -42,6 +43,7 @@ def __init__(
         super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
         self.parallel_devices = parallel_devices
         self.cluster_environment = cluster_environment
+        self._layer_sync: Optional[LayerSync] = None
 
     @property
     @abstractmethod
@@ -105,21 +107,6 @@ def torch_distributed_backend(self):
         torch_backend = "nccl" if self.root_device.type == "cuda" else "gloo"
         return torch_backend
 
-    @staticmethod
-    def configure_sync_batchnorm(model: "pl.LightningModule") -> "pl.LightningModule":
-        """Add global batchnorm for a model spread across multiple GPUs and nodes.
-
-        Override to synchronize batchnorm between specific process groups instead
-        of the whole world or use a different sync_bn like `apex`'s version.
-
-        Args:
-            model: pointer to current :class:`LightningModule`.
-
-        Return:
-            LightningModule with batchnorm layers synchronized between process groups
-        """
-        return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-
     @contextmanager
     def block_backward_sync(self):
         """Blocks ddp sync gradients behaviour on backwards pass.

pytorch_lightning/trainer/connectors/accelerator_connector.py
Lines changed: 13 additions & 3 deletions

@@ -46,6 +46,7 @@
     SLURMEnvironment,
     TorchElasticEnvironment,
 )
+from pytorch_lightning.plugins.layer_sync import LayerSync, NativeSyncBatchNorm
 from pytorch_lightning.strategies import (
     DDP2Strategy,
     DDPFullyShardedStrategy,
@@ -150,7 +151,6 @@ def __init__(
         # TODO: move to gpu accelerator
         torch.backends.cudnn.benchmark = self.benchmark
         self.replace_sampler_ddp = replace_sampler_ddp
-        self.sync_batchnorm = sync_batchnorm
         self._init_deterministic(deterministic)
 
         # 1. Parsing flags
@@ -169,6 +169,7 @@
         self._precision_plugin_flag: Optional[PrecisionPlugin] = None
         self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
         self._parallel_devices: List[Union[int, torch.device]] = []
+        self._layer_sync: Optional[LayerSync] = NativeSyncBatchNorm() if sync_batchnorm else None
         self.checkpoint_io: Optional[CheckpointIO] = None
         self._amp_type_flag: Optional[LightningEnum] = None
         self._amp_level_flag: Optional[str] = amp_level
@@ -180,6 +181,7 @@
             plugins=plugins,
             amp_type=amp_type,
             amp_level=amp_level,
+            sync_batchnorm=sync_batchnorm,
         )
         self._check_device_config_and_set_final_flags(
             devices=devices, num_nodes=num_nodes, num_processes=num_processes, gpus=gpus, ipus=ipus, tpu_cores=tpu_cores
@@ -230,6 +232,7 @@ def _check_config_and_set_final_flags(
         plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]],
         amp_type: str,
         amp_level: Optional[str],
+        sync_batchnorm: bool,
     ) -> None:
         """This method checks:
 
@@ -317,6 +320,13 @@ def _check_config_and_set_final_flags(
                 self.checkpoint_io = plugin
             elif isinstance(plugin, ClusterEnvironment):
                 self._cluster_environment_flag = plugin
+            elif isinstance(plugin, LayerSync):
+                if sync_batchnorm and not isinstance(plugin, NativeSyncBatchNorm):
+                    raise MisconfigurationException(
+                        f"You set `Trainer(sync_batchnorm=True)` and provided a `{plugin.__class__.__name__}`"
+                        " plugin, but this is not allowed. Choose one or the other."
+                    )
+                self._layer_sync = plugin
             else:
                 raise MisconfigurationException(
                     f"Found invalid type for plugin {plugin}. Expected a precision plugin or training strategy."
@@ -715,8 +725,8 @@ def _lazy_init_strategy(self) -> None:
         self.strategy.parallel_devices = self._parallel_devices
         if hasattr(self.strategy, "num_nodes"):
            self.strategy._num_nodes = self._num_nodes_flag
-        if hasattr(self.strategy, "sync_batchnorm"):
-            self.strategy.sync_batchnorm = self.sync_batchnorm
+        if hasattr(self.strategy, "_layer_sync"):
+            self.strategy._layer_sync = self._layer_sync
         if hasattr(self.strategy, "set_world_ranks"):
             self.strategy.set_world_ranks()
         self.strategy._configure_launcher()
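
Reading the conflict check above: combining `sync_batchnorm=True` with a `LayerSync` plugin that is not a `NativeSyncBatchNorm` raises a `MisconfigurationException`, while a matching plugin is accepted. `SomeOtherLayerSync` below is a hypothetical stand-in:

    # Hedged illustration of the allowed/forbidden combinations handled above.
    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import NativeSyncBatchNorm

    Trainer(sync_batchnorm=True, plugins=[NativeSyncBatchNorm()])   # accepted: the plugin matches the flag
    # Trainer(sync_batchnorm=True, plugins=[SomeOtherLayerSync()])  # rejected with a MisconfigurationException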
