
Commit 0374fe6

Authored by krshrimali, pre-commit-ci[bot], rohitgr7, ananthsub, and carmocca
Support gradient accumulation using Horovod's backward_passes_per_step (#11911)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent cf64f34 commit 0374fe6

3 files changed: +94 −21 lines changed


CHANGELOG.md

Lines changed: 4 additions & 1 deletion
@@ -9,7 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

-- Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs
+- Enable gradient accumulation using Horovod's `backward_passes_per_step` ([#11911](https://github.com/PyTorchLightning/pytorch-lightning/pull/11911))
+
+
+- Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs ([#11008](https://github.com/PyTorchLightning/pytorch-lightning/pull/11008))


 - Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/pull/10601))
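
A minimal usage sketch of the behaviour this entry adds, assuming Horovod is installed and the script is launched with `horovodrun` (the `ToyModule` below is an illustrative stand-in for any LightningModule; the option values mirror the tests added in this commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer


class ToyModule(LightningModule):
    """Illustrative LightningModule; any model works the same way."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=4,
        accumulate_grad_batches=2,  # now forwarded to Horovod's `backward_passes_per_step`
        strategy="horovod",
    )
    trainer.fit(ToyModule(), train_dataloaders=train_data)
```

Launched with, for example, `horovodrun -np 2 python train.py`, each worker accumulates gradients locally and Horovod performs the allreduce only on every second backward pass.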

pytorch_lightning/strategies/horovod.py

Lines changed: 21 additions & 3 deletions
@@ -26,6 +26,7 @@
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as dist_group
 from pytorch_lightning.utilities.distributed import ReduceOp
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_only

@@ -76,6 +77,11 @@ def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs

+    @property
+    def handles_gradient_accumulation(self) -> bool:
+        """Whether the plugin handles gradient accumulation internally."""
+        return True
+
     def setup(self, trainer: "pl.Trainer") -> None:
         self.model_to_device()

@@ -111,7 +117,13 @@ def _unpack_lightning_optimizer(opt):
         for optimizer in optimizers:
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)

-        self.optimizers = self._wrap_optimizers(optimizers)
+        accumulation_scheduler = trainer.accumulation_scheduler
+        if accumulation_scheduler.epochs != [0]:
+            raise MisconfigurationException(
+                "Horovod currently does not support different `accumulate_grad_batches` at different epochs."
+            )
+
+        self.optimizers = self._wrap_optimizers(optimizers, trainer.accumulate_grad_batches)
         for optimizer in self.optimizers:
             # Synchronization will be performed explicitly following backward()
             self._exit_stack.enter_context(optimizer.skip_synchronize())

@@ -181,10 +193,16 @@ def post_backward(self, closure_loss: torch.Tensor) -> None:
         for optimizer in self.optimizers:
             optimizer.synchronize()

-    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["hvd.DistributedOptimizer"]:
+    def _wrap_optimizers(
+        self, optimizers: List[Optimizer], accumulate_grad_batches: int
+    ) -> List["hvd.DistributedOptimizer"]:
         """Wraps optimizers to perform gradient aggregation via allreduce."""
         return [
-            hvd.DistributedOptimizer(opt, named_parameters=self._filter_named_parameters(self.lightning_module, opt))
+            hvd.DistributedOptimizer(
+                opt,
+                backward_passes_per_step=accumulate_grad_batches,
+                named_parameters=self._filter_named_parameters(self.lightning_module, opt),
+            )
             if "horovod" not in str(opt.__class__)
             else opt
             for opt in optimizers
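
For reference, the Horovod primitive the strategy now relies on can also be used directly, as sketched below. This follows Horovod's documented gradient-accumulation pattern rather than the strategy's own code path, and the model, optimizer, and data are placeholders:

```python
import torch
import horovod.torch as hvd

hvd.init()
accum = 2  # what Lightning forwards from Trainer(accumulate_grad_batches=2)

model = torch.nn.Linear(32, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = hvd.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    backward_passes_per_step=accum,  # allreduce only every `accum` backward passes
)

optimizer.zero_grad()
for i, batch in enumerate(torch.randn(4, 8, 32)):  # four toy batches
    loss = model(batch).sum() / accum  # scale so accumulated gradients average out
    loss.backward()  # Horovod submits the allreduce on every `accum`-th pass
    if (i + 1) % accum == 0:
        optimizer.step()  # waits for the allreduce, then applies the update
        optimizer.zero_grad()
```

Within Lightning, the strategy wires `Trainer(accumulate_grad_batches=...)` into `backward_passes_per_step` once, when the optimizers are wrapped at setup time, which is also why a per-epoch schedule such as `accumulate_grad_batches={0: 4, 2: 2}` is rejected with a `MisconfigurationException`.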

tests/models/test_horovod.py

Lines changed: 69 additions & 17 deletions
@@ -30,6 +30,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.advanced_models import BasicGAN
 from tests.helpers.runif import RunIf

@@ -42,25 +43,23 @@
 TEST_SCRIPT = os.path.join(os.path.dirname(__file__), "data", "horovod", "train_default_model.py")


-def _run_horovod(trainer_options, on_gpu=False):
+def _run_horovod(trainer_options):
     """Execute the training script across multiple workers in parallel."""
-    num_processes = trainer_options.get("gpus", 2)
-    # for Horovod, we interpret `gpus` to be set per worker
-    trainer_options.update(gpus=1 if on_gpu else None)
+    devices = trainer_options.get("devices", 1)
     tutils.reset_seed()
     # TODO: Find out why coverage breaks CI.
     # append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else ''
     # str(num_processes), sys.executable, '-m', 'coverage', 'run', '--source', 'pytorch_lightning', append,
     cmdline = [
         "horovodrun",
         "-np",
-        str(num_processes),
+        str(devices),
         sys.executable,
         TEST_SCRIPT,
         "--trainer-options",
         shlex.quote(json.dumps(trainer_options)),
     ]
-    if on_gpu:
+    if trainer_options.get("accelerator", "cpu") == "gpu":
         cmdline += ["--on-gpu"]
     exit_code = subprocess.call(" ".join(cmdline), shell=True, env=os.environ.copy())
     assert exit_code == 0

@@ -82,6 +81,20 @@ def test_horovod_cpu(tmpdir):
     _run_horovod(trainer_options)


+@RunIf(skip_windows=True, horovod=True, skip_49370=True)
+def test_horovod_cpu_accumulate_grad_batches(tmpdir):
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        enable_progress_bar=False,
+        max_epochs=1,
+        limit_train_batches=4,
+        limit_val_batches=0,
+        accumulate_grad_batches=2,
+        strategy="horovod",
+    )
+    _run_horovod(trainer_options)
+
+
 @RunIf(skip_windows=True, horovod=True, skip_49370=True)
 def test_horovod_cpu_clip_grad_by_value(tmpdir):
     """Test Horovod running multi-process on CPU."""

@@ -125,10 +138,44 @@ def test_horovod_multi_gpu(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
+        strategy="horovod",
+    )
+    _run_horovod(trainer_options)
+
+
+@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
+def test_horovod_multi_gpu_accumulate_grad_batches(tmpdir):
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        enable_progress_bar=False,
+        max_epochs=1,
+        limit_train_batches=4,
+        limit_val_batches=0,
+        accumulate_grad_batches=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)
+
+
+@RunIf(horovod=True, skip_windows=True)
+def test_horovod_raises_unsupported_accumulate_grad_batches(tmpdir):
+    """Ensure MisConfigurationException for different `accumulate_grad_batches` at different epochs for Horovod
+    Strategy on multi-gpus."""
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        enable_progress_bar=False,
+        accumulate_grad_batches={0: 4, 2: 2},
+        accelerator="auto",
+        devices=1,
+        strategy="horovod",
+    )
+    with pytest.raises(MisconfigurationException, match="Horovod.*does not support.*accumulate_grad_batches"):
+        trainer.fit(model)


 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)

@@ -143,10 +190,11 @@ def test_horovod_multi_gpu_grad_by_value(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)


 # todo: need to be fixed :]

@@ -164,12 +212,13 @@ def test_horovod_apex(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
         amp_backend="apex",
         precision=16,
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)


 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)

@@ -183,12 +232,13 @@ def test_horovod_amp(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
         amp_backend="native",
         precision=16,
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)


 @RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)

@@ -202,10 +252,11 @@ def test_horovod_gather(tmpdir):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=2,
+        accelerator="gpu",
+        devices=2,
         strategy="horovod",
     )
-    _run_horovod(trainer_options, on_gpu=True)
+    _run_horovod(trainer_options)


 @RunIf(min_gpus=1, skip_windows=True, horovod_nccl=True)

@@ -227,7 +278,8 @@ def validation_step(self, batch, *args, **kwargs):
         max_epochs=1,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        gpus=1,
+        accelerator="gpu",
+        devices=1,
         strategy="horovod",
     )
     tpipes.run_model_test_without_loggers(trainer_options, model)
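
For orientation, the command line the updated `_run_horovod` helper assembles for the new CPU accumulation test comes out roughly as below (a sketch that reproduces the helper's logic; the `default_root_dir=tmpdir` option is omitted for brevity):

```python
import json
import shlex
import sys

# Same options as `test_horovod_cpu_accumulate_grad_batches`, minus the tmpdir-specific one.
trainer_options = dict(
    enable_progress_bar=False,
    max_epochs=1,
    limit_train_batches=4,
    limit_val_batches=0,
    accumulate_grad_batches=2,
    strategy="horovod",
)

devices = trainer_options.get("devices", 1)  # CPU test: defaults to a single worker
cmdline = [
    "horovodrun",
    "-np",
    str(devices),
    sys.executable,
    "tests/models/data/horovod/train_default_model.py",  # TEST_SCRIPT in this module
    "--trainer-options",
    shlex.quote(json.dumps(trainer_options)),
]
if trainer_options.get("accelerator", "cpu") == "gpu":
    cmdline += ["--on-gpu"]  # only appended for the GPU variants

print(" ".join(cmdline))
```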
