
Commit 15a4959

Authored and committed by Sean Naren
Remove partitioning of model in ZeRO 3 (#10655)
(cherry picked from commit c66cd12)
1 parent e3503a1 commit 15a4959

File tree: 5 files changed (+20 lines, −58 lines)

.azure-pipelines/gpu-tests.yml

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ jobs:
   - bash: |
       python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
       pip install fairscale==0.4.0
-      pip install deepspeed==0.5.4
+      pip install deepspeed==0.5.7
       pip install . --requirement requirements/devel.txt
       pip list
     displayName: 'Install dependencies'

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue when torch-scripting a `LightningModule` after training with `Trainer(sync_batchnorm=True)` ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078))
 - Fixed an `AttributeError` occurring when using a `CombinedLoader` (multiple dataloaders) for prediction ([#11111](https://github.com/PyTorchLightning/pytorch-lightning/pull/11111))
 
+### Changed
+
+- DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655))
 
 ## [1.5.6] - 2021-12-15
 
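To illustrate the changelog entry above: after this change the plugin no longer accepts (or needs) a `partition_module` flag, and ZeRO Stage 3 is requested with `stage=3` alone. A minimal sketch of the resulting usage, mirroring the test configuration later in this diff; the toy module and hyperparameters below are illustrative assumptions, not part of this commit:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin


class MinimalModule(LightningModule):
    """Tiny stand-in model, only for illustration."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)


# ZeRO Stage 3 is enabled with `stage=3` alone; the removed
# `partition_module` argument is no longer accepted by the plugin.
trainer = Trainer(
    strategy=DeepSpeedPlugin(stage=3),
    gpus=1,
    precision=16,
    fast_dev_run=True,
)
trainer.fit(MinimalModule())
```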

dockers/base-cuda/Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ RUN \
 
 RUN \
     # install DeepSpeed
-    pip install deepspeed==0.5.4
+    pip install deepspeed==0.5.7
 
 RUN \
     # Show what we have

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 2 additions & 17 deletions
@@ -129,7 +129,6 @@ def __init__(
         contiguous_memory_optimization: bool = False,
         synchronize_checkpoint_boundary: bool = False,
         load_full_weights: bool = False,
-        partition_module: bool = True,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -259,12 +258,6 @@ def __init__(
             load_full_weights: True when loading a single checkpoint file containing the model state dict
                 when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
                 per worker.
-
-            partition_module: When True, partitions the ``LightningModule`` across devices when using ZeRO Stage 3.
-                This is the default behaviour to ensure that the entire module is appropriately initialized
-                for DeepSpeed. When False we do not explicitly convert the model, which is fine if NO layers
-                or ALL layers are defined in ``configure_sharded_model``. This is useful for layers such as
-                ``torch.nn.RNN`` which do internal logic when moving to device.
         """
         if not _DEEPSPEED_AVAILABLE:
             raise MisconfigurationException(
@@ -317,7 +310,6 @@ def __init__(
 
         self.remote_device = remote_device
         self.load_full_weights = load_full_weights
-        self.partition_module = partition_module
 
         # default FP16 parameters.
         self.loss_scale = loss_scale
@@ -463,13 +455,6 @@ def init_deepspeed(self):
         precision = self.lightning_module.trainer.accelerator.precision
         model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)
 
-        if self.zero_stage_3 and self.partition_module:
-            # Ensure the entire model has been moved to the appropriate device
-            dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32
-            deepspeed.zero.Init(
-                module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
-            )
-
         if self.lightning_module.trainer and self.lightning_module.trainer.training:
             self._initialize_deepspeed_train(model)
         else:
@@ -524,7 +509,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:
             assert self._config_initialized
             dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32
             model_parallel_context = deepspeed.zero.Init(
-                remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
+                remote_device=self.remote_device, pin_memory=True, config_dict_or_path=self.config, dtype=dtype
             )
         else:
             model_parallel_context = super().model_sharded_context()
@@ -554,7 +539,7 @@ def _initialize_deepspeed_inference(self, model):
         optimizer, lr_scheduler, _ = self._init_optimizers()
         scheduler = lr_scheduler["scheduler"]
         inference_config = {
-            # todo: this is required for DeepSpeed throughput timers, or throughput timers will be incorrect
+            # todo: this is required for DeepSpeed throughput timers
            "train_micro_batch_size_per_gpu": 1
         }
         if "fp16" in self.config:

tests/plugins/test_deepspeed_plugin.py

Lines changed: 13 additions & 39 deletions
@@ -595,7 +595,9 @@ def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config
     _assert_save_model_is_equal(model, tmpdir, trainer)
 
 
-def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumulate_grad_batches: int = 2):
+@pytest.mark.parametrize(("accumulate_grad_batches", "automatic_optimization"), [(1, False), (2, True)])
+@RunIf(min_gpus=2, deepspeed=True, standalone=True)
+def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir, automatic_optimization, accumulate_grad_batches):
     seed_everything(1)
     if automatic_optimization:
         model = ModelParallelClassificationModel()
@@ -630,13 +632,6 @@ def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumu
     assert results[0]["test_acc"] > 0.7
 
 
-@RunIf(min_gpus=2, deepspeed=True, standalone=True)
-def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
-    """Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, and
-    see convergence."""
-    run_checkpoint_test(tmpdir)
-
-
 @RunIf(min_gpus=1, deepspeed=True, standalone=True)
 def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir):
     """Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the
@@ -718,24 +713,9 @@ def on_train_batch_start(
     trainer.fit(model, datamodule=dm, ckpt_path=ck.best_model_path)
 
 
+@pytest.mark.parametrize("offload_optimizer", [False, True])
 @RunIf(min_gpus=2, deepspeed=True, standalone=True)
-def test_deepspeed_multigpu_stage_3_checkpointing_full_weights_manual(tmpdir):
-    """Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
-    where we save the full weights to one file."""
-    run_checkpoint_test(tmpdir, automatic_optimization=False, accumulate_grad_batches=1)
-
-
-@RunIf(min_gpus=2, deepspeed=True, standalone=True)
-def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir):
-    _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer=False)
-
-
-@RunIf(min_gpus=2, deepspeed=True, standalone=True)
-def test_deepspeed_multigpu_stage_2_accumulated_grad_batches_offload_optimizer(tmpdir):
-    _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer=True)
-
-
-def _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer):
+def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer):
     """Test to ensure with Stage 2 and multiple GPUs, accumulated grad batches works."""
     seed_everything(42)
 
@@ -781,6 +761,8 @@ def test_deepspeed_multigpu_test(tmpdir):
     trainer.test(model)
 
 
+# TODO(Sean): Once partial parameter partitioning is supported this test should be re-enabled
+@pytest.mark.skip("Partial parameter partitioning for DeepSpeed is currently broken.")
 @RunIf(min_gpus=1, deepspeed=True, standalone=True)
 def test_deepspeed_multigpu_partial_partition_parameters(tmpdir):
     """Test to ensure that a module that defines a layer inside the ``__init__`` and ``configure_sharded_model``
@@ -824,7 +806,7 @@ def on_train_epoch_start(self) -> None:
     model = TestModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
-        strategy=DeepSpeedPlugin(stage=3, partition_module=False),
+        strategy=DeepSpeedPlugin(stage=3),
         gpus=1,
         fast_dev_run=True,
         precision=16,
@@ -941,22 +923,14 @@ def test_dataloader(self):
 
 
 @mock.patch("torch.optim.lr_scheduler.StepLR.step", autospec=True)
+@pytest.mark.parametrize("interval", ["step", "epoch"])
+@pytest.mark.parametrize("max_epoch", [2])
+@pytest.mark.parametrize("limit_train_batches", [2])
 @RunIf(min_gpus=1, deepspeed=True, standalone=True)
-def test_deepspeed_scheduler_step_count(mock_step):
+def test_scheduler_step_count(mock_step, max_epoch, limit_train_batches, interval):
     """Test to ensure that the scheduler is called the correct amount of times during training when scheduler is
-    set to step."""
-    _run_scheduler_test(mock_step, max_epoch=2, limit_train_batches=2, interval="step")
-
-
-@mock.patch("torch.optim.lr_scheduler.StepLR.step", autospec=True)
-@RunIf(min_gpus=1, deepspeed=True, standalone=True)
-def test_deepspeed_scheduler_step_count_epoch(mock_step):
-    """Test to ensure that the scheduler is called the correct amount of times during training when scheduler is
-    set to epoch."""
-    _run_scheduler_test(mock_step, max_epoch=2, limit_train_batches=2, interval="epoch")
-
+    set to step or epoch."""
 
-def _run_scheduler_test(mock_step, max_epoch, limit_train_batches, interval):
     class TestModel(BoringModel):
         def configure_optimizers(self):
             optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
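As a side note on the test refactor above: the per-variant wrapper functions are replaced by stacked `pytest.mark.parametrize` decorators, which take the cartesian product of their value lists so a single test body covers every combination. A small standalone sketch of that pattern; the names below are illustrative and not from the test suite:

```python
import pytest


@pytest.mark.parametrize("interval", ["step", "epoch"])
@pytest.mark.parametrize("offload_optimizer", [False, True])
def test_parametrize_product(interval, offload_optimizer):
    # Collected four times: (step, False), (step, True), (epoch, False), (epoch, True).
    assert interval in ("step", "epoch")
    assert isinstance(offload_optimizer, bool)
```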
