Commit 7aee00c

jona-0, pre-commit-ci[bot], Sean Naren, ananthsub, and awaelchli authored
[DeepSpeed] fix flag forwarding in DeepSpeedPlugin (#10899)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent d7b6e87 commit 7aee00c

File tree

3 files changed, +33 -3 lines changed


CHANGELOG.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -261,7 +261,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed support for `CombinedLoader` while checking for warning raised with eval dataloaders ([#10994](https://github.com/PyTorchLightning/pytorch-lightning/pull/10994))
 
 
--
+- Fixed a bug where the DeepSpeedPlugin arguments `cpu_checkpointing` and `contiguous_memory_optimization` were not being forwarded to deepspeed correctly ([#10874](https://github.com/PyTorchLightning/pytorch-lightning/issues/10874))
 
 
 -
```
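
For context, here is a minimal usage sketch (not part of this commit) of the two affected flags; the argument names match the test added below, while the surrounding Trainer settings are illustrative only:

```python
# Minimal sketch (not from this commit): the two DeepSpeedPlugin flags affected
# by the fix. The Trainer settings are illustrative placeholders.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

plugin = DeepSpeedPlugin(
    partition_activations=True,
    cpu_checkpointing=True,               # forwarded as checkpoint_in_cpu after this fix
    contiguous_memory_optimization=True,  # forwarded as contiguous_checkpointing after this fix
)
trainer = Trainer(strategy=plugin, precision=16, gpus=1)
```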

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -527,8 +527,8 @@ def _set_deepspeed_activation_checkpointing(self):
             deepspeed.checkpointing.configure(
                 mpu_=None,
                 partition_activations=checkpoint_config.get("partition_activations"),
-                contiguous_checkpointing=checkpoint_config.get("contiguous_checkpointing"),
-                checkpoint_in_cpu=checkpoint_config.get("checkpoint_in_cpu"),
+                contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
+                checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
                 profile=checkpoint_config.get("profile"),
             )
 
```
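
The essence of the fix is a rename on the lookup side: the user-facing plugin argument names stored in the checkpoint config differ from the keyword names that deepspeed.checkpointing.configure expects. A small sketch of the mapping implied by the diff above (the dict itself is illustrative and does not exist in the codebase):

```python
# Mapping implied by the diff above: checkpoint-config key (left) to the
# keyword expected by deepspeed.checkpointing.configure (right).
# Illustrative only; pytorch_lightning does not define this dict.
CHECKPOINT_ARG_MAP = {
    "partition_activations": "partition_activations",
    "contiguous_memory_optimization": "contiguous_checkpointing",
    "cpu_checkpointing": "checkpoint_in_cpu",
    "profile": "profile",
}
```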

tests/plugins/test_deepspeed_plugin.py

Lines changed: 30 additions & 0 deletions

```diff
@@ -361,6 +361,36 @@ def test_deepspeed_custom_activation_checkpointing_params(tmpdir):
     assert checkpoint_config["synchronize_checkpoint_boundary"]
 
 
+@RunIf(min_gpus=1, deepspeed=True, standalone=True)
+def test_deepspeed_custom_activation_checkpointing_params_forwarded(tmpdir):
+    """Ensure if we modify the activation checkpointing parameters, we pass these to
+    deepspeed.checkpointing.configure correctly."""
+    ds = DeepSpeedPlugin(
+        partition_activations=True,
+        cpu_checkpointing=True,
+        contiguous_memory_optimization=True,
+        synchronize_checkpoint_boundary=True,
+    )
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        enable_progress_bar=False,
+        fast_dev_run=1,
+        strategy=ds,
+        precision=16,
+        gpus=1,
+    )
+    with mock.patch(
+        "deepspeed.checkpointing.configure", wraps=deepspeed.checkpointing.configure
+    ) as deepspeed_checkpointing_configure:
+        trainer.fit(model)
+
+    deepspeed_checkpointing_configure.assert_called_with(
+        mpu_=None, partition_activations=True, contiguous_checkpointing=True, checkpoint_in_cpu=True, profile=None
+    )
+
+
 @RunIf(min_gpus=1, deepspeed=True)
 def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_config):
     """Ensure if we use a config and turn off offload_optimizer, that this is set to False within the config."""
```
