
Commit 41e3be1

Remove call_configure_sharded_model lifecycle property (#9612)
1 parent 2b2537d commit 41e3be1

9 files changed: +22 -129 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -351,6 +351,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated properties `DeepSpeedPlugin.cpu_offload*` in favor of `offload_optimizer`, `offload_parameters` and `pin_memory` ([#9244](https://github.com/PyTorchLightning/pytorch-lightning/pull/9244))


+- Removed `call_configure_sharded_model_hook` property from `Accelerator` and `TrainingTypePlugin` ([#9612](https://github.com/PyTorchLightning/pytorch-lightning/pull/9612))
+
+
 ### Fixed

pytorch_lightning/accelerators/accelerator.py

Lines changed: 0 additions & 14 deletions
@@ -401,20 +401,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """
         self.training_type_plugin.save_checkpoint(checkpoint, filepath)

-    @property
-    def call_configure_sharded_model_hook(self) -> bool:
-        """Allow model parallel hook to be called in suitable environments determined by the training type plugin.
-        This is useful for when we want to shard the model once within fit.
-
-        Returns:
-            True if we want to call the model parallel setup hook.
-        """
-        return self.training_type_plugin.call_configure_sharded_model_hook
-
-    @call_configure_sharded_model_hook.setter
-    def call_configure_sharded_model_hook(self, mode: bool) -> None:
-        self.training_type_plugin.call_configure_sharded_model_hook = mode
-
     @property
     def setup_optimizers_in_pre_dispatch(self) -> bool:
         """Override to delay setting optimizers and schedulers till after dispatch. This is useful when the

pytorch_lightning/core/hooks.py

Lines changed: 2 additions & 6 deletions
@@ -297,12 +297,8 @@ def configure_sharded_model(self) -> None:
         where we'd like to shard the model instantly, which is useful for extremely large models which can save
         memory and initialization time.

-        The accelerator manages whether to call this hook at every given stage.
-        For sharded plugins where model parallelism is required, the hook is usually on called once
-        to initialize the sharded parameters, and not called again in the same process.
-
-        By default for accelerators/plugins that do not use model sharding techniques,
-        this hook is called during each fit/val/test/predict stages.
+        This hook is called during each of fit/val/test/predict stages in the same process, so ensure that
+        implementation of this hook is idempotent.
         """

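Because the hook can now run more than once in the same process, overrides should be written to be idempotent. Below is a minimal sketch of such an override, modelled on the `TestFSDPModel` changes later in this diff; the fairscale imports, the module name, and the omission of training logic are illustrative assumptions, not part of this commit.

```python
from typing import Optional

import torch
from fairscale.nn import wrap  # assumed available; used the same way as in the FSDP tests below
from fairscale.nn.data_parallel import FullyShardedDataParallel
from pytorch_lightning import LightningModule


class MyShardedModule(LightningModule):  # hypothetical example module
    def __init__(self):
        super().__init__()
        # defer building the layers; they are created and wrapped in the hook below
        self.layer: Optional[torch.nn.Module] = None

    def configure_sharded_model(self) -> None:
        # the hook may run for each of fit/validate/test/predict in the same
        # process, so return early if the model has already been sharded
        if isinstance(self.layer, FullyShardedDataParallel):
            return
        if self.layer is None:
            self.layer = torch.nn.Sequential(
                torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)
            )
        # fairscale's `wrap` only wraps inside an `enable_wrap` context (which the
        # FSDP plugin's model_sharded_context provides); otherwise it returns the
        # module unchanged
        self.layer = wrap(self.layer)

    def configure_optimizers(self):
        # training_step / dataloaders omitted for brevity
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
```

With a guard like this, a second entry point in the same process (for example `trainer.validate(...)` after `trainer.fit(...)`) leaves the already-wrapped modules untouched instead of wrapping them twice.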

pytorch_lightning/plugins/training_type/fully_sharded.py

Lines changed: 0 additions & 11 deletions
@@ -141,17 +141,6 @@ def wrap_policy(*args, **kwargs):
         ):
             yield

-    def setup_environment(self) -> None:
-        super().setup_environment()
-        model_call_configure_sharded_model_hook = getattr(
-            self.lightning_module, "call_configure_sharded_model_hook", False
-        )
-        if not model_call_configure_sharded_model_hook:
-            # if model has not called configure sharded model, we reset
-            # the training type plugin's call_configure_sharded_model_hook
-            # to give trainer a chance to configure.
-            self.call_configure_sharded_model_hook = True
-
     def configure_ddp(self) -> None:
         if not self.cpu_offload:
             # When using CPU Offload, FSDP will manage the CUDA movement for us.

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 0 additions & 14 deletions
@@ -39,7 +39,6 @@ def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None:
         self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None
         checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO()
         self._checkpoint_io = checkpoint_io
-        self._call_configure_sharded_model_hook = True

     @property
     def checkpoint_io(self) -> CheckpointIO:
@@ -281,19 +280,6 @@ def model_sharded_context(self) -> Generator:
         """
         yield

-    @property
-    def call_configure_sharded_model_hook(self) -> bool:
-        """Allow model parallel hook to be called in suitable environments determined by the training type plugin.
-
-        This is useful for when we want to shard the model once within fit.
-        Returns: True if we want to call the model parallel setup hook.
-        """
-        return self._call_configure_sharded_model_hook
-
-    @call_configure_sharded_model_hook.setter
-    def call_configure_sharded_model_hook(self, mode: bool) -> None:
-        self._call_configure_sharded_model_hook = mode
-
     @abstractmethod
     def teardown(self) -> None:
         """This method is called to teardown the training process.

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 12 deletions
@@ -1291,18 +1291,9 @@ def _call_setup_hook(self) -> None:
         self.accelerator.barrier("post_setup")

     def _call_configure_sharded_model(self) -> None:
-        # Call configure sharded model hook if accelerator requests. In some cases
-        # we will not call the hook; the hook has initialized the sharded model for example.
-
-        # used on the model if the user re-create a trainer with resume_from_checkpoint
-        model = self.lightning_module
-        model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
-        if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook:
-            with self.accelerator.model_sharded_context():
-                self.call_hook("configure_sharded_model")
-                self.call_hook("on_configure_sharded_model")
-            model.call_configure_sharded_model_hook = True
-            self.accelerator.call_configure_sharded_model_hook = False
+        with self.accelerator.model_sharded_context():
+            self.call_hook("configure_sharded_model")
+            self.call_hook("on_configure_sharded_model")

     def _call_teardown_hook(self) -> None:
         fn = self.state.fn._setup_fn
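With the lifecycle flag gone, `_call_configure_sharded_model` simply enters the plugin's `model_sharded_context()` and fires the hooks on every trainer entry point. A hypothetical, test-style sketch of the resulting behaviour (it reuses the `BoringModel` helper from the Lightning test suite and counts hook calls rather than asserting the removed properties):

```python
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # Lightning test-suite helper


class CountingModel(BoringModel):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.configure_sharded_model_calls = 0

    def configure_sharded_model(self):
        self.configure_sharded_model_calls += 1


def test_configure_sharded_model_called_per_entry_point(tmpdir):
    model = CountingModel()
    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1)
    trainer.fit(model)
    trainer.validate(model)
    # the hook should now run once per fit/validate/test/predict call in the
    # same process, instead of at most once per process as before this commit
    assert model.configure_sharded_model_calls == 2
```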

tests/accelerators/test_common.py

Lines changed: 0 additions & 55 deletions
@@ -16,7 +16,6 @@

 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
-from pytorch_lightning.plugins import SingleDevicePlugin
 from tests.accelerators.test_dp import CustomClassificationModelDP
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.datamodules import ClassifDataModule
@@ -77,57 +76,3 @@ def configure_sharded_model(self):
     trainer.fit(model)

     assert model.configure_sharded_model_called
-
-
-class DummyModel(BoringModel):
-    def __init__(self):
-        super().__init__()
-        self.configure_sharded_model_called = False
-
-    def configure_sharded_model(self):
-        self.configure_sharded_model_called = True
-
-
-def test_configure_sharded_model_false(tmpdir):
-    """Ensure ``configure_sharded_model`` is not called, when turned off."""
-
-    class CustomPlugin(SingleDevicePlugin):
-        @property
-        def call_configure_sharded_model_hook(self) -> bool:
-            return False
-
-    model = DummyModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_train_batches=2,
-        limit_val_batches=2,
-        max_epochs=1,
-        plugins=CustomPlugin(device=torch.device("cpu")),
-    )
-    trainer.fit(model)
-
-    assert not model.configure_sharded_model_called
-
-
-def test_accelerator_configure_sharded_model_called_once(tmpdir):
-    """Ensure that the configure sharded model hook is called, and set to False after to ensure not called
-    again."""
-
-    model = DummyModel()
-    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1)
-    assert trainer.accelerator.call_configure_sharded_model_hook is True
-    trainer.fit(model)
-    assert trainer.accelerator.call_configure_sharded_model_hook is False
-
-
-def test_configure_sharded_model_called_once(tmpdir):
-    """Ensure ``configure_sharded_model`` is only called once."""
-
-    model = DummyModel()
-    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1)
-    trainer.fit(model)
-
-    assert model.configure_sharded_model_called
-    model.configure_sharded_model_called = False
-
-    assert not model.configure_sharded_model_called

tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py

Lines changed: 13 additions & 16 deletions
@@ -49,27 +49,29 @@ def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir):


 class TestFSDPModel(BoringModel):
-    def setup(self, stage: str) -> None:
-        if stage != "fit":
-            # when running stages like test, validate, and predict, we will skip setting up,
-            # will directly use the module itself unless we load from checkpoint
-            return
-        # resetting call_configure_sharded_model_hook attribute so that we could call
-        # configure sharded model
-        self.call_configure_sharded_model_hook = False
-        # for loading full state dict, we first need to create a new unwrapped model
-        # to load state dict and then wrapping
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.layer: Optional[torch.nn.Module] = None
+
+    def _init_model(self) -> None:
         self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

+    def setup(self, stage: str) -> None:
+        if self.layer is None:
+            self._init_model()
+
     def configure_sharded_model(self) -> None:
+        # the model is already wrapped with FSDP: no need to wrap again!
+        if isinstance(self.layer, FullyShardedDataParallel):
+            return
         for i, layer in enumerate(self.layer):
             if i % 2 == 0:
                 self.layer[i] = wrap(layer)
         self.layer = wrap(self.layer)

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # when loading full state dict, we first need to create a new unwrapped model
-        self.setup("fit")
+        self._init_model()

     def configure_optimizers(self):
         return torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -131,13 +133,8 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
 def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
     trainer.fit(model)

-    model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
-    trainer_accelerator_call_configure_sharded_model_hook = trainer.accelerator.call_configure_sharded_model_hook
-
     model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

-    assert model_call_configure_sharded_model_hook
-    assert not trainer_accelerator_call_configure_sharded_model_hook
     trainer.save_checkpoint(model_path, weights_only=True)

     _assert_save_equality(trainer, model_path, cls=TestFSDPModel)

tests/plugins/test_deepspeed_plugin.py

Lines changed: 1 addition & 1 deletion
@@ -603,7 +603,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
     run_checkpoint_test(tmpdir)


-@RunIf(min_gpus=1, deepspeed=True, special=False)
+@RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir):
     """Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the
     optimizer state and scheduler states cannot be restored."""
