
Commit 2aa4497

Remove call_configure_sharded_model lifecycle property
1 parent d022f6f commit 2aa4497

File tree

pytorch_lightning/accelerators/accelerator.py
pytorch_lightning/plugins/training_type/fully_sharded.py
pytorch_lightning/plugins/training_type/training_type_plugin.py
pytorch_lightning/trainer/trainer.py
tests/accelerators/test_common.py
tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py

6 files changed: +16 -122 lines


pytorch_lightning/accelerators/accelerator.py

Lines changed: 0 additions & 14 deletions
@@ -401,20 +401,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """
         self.training_type_plugin.save_checkpoint(checkpoint, filepath)
 
-    @property
-    def call_configure_sharded_model_hook(self) -> bool:
-        """Allow model parallel hook to be called in suitable environments determined by the training type plugin.
-        This is useful for when we want to shard the model once within fit.
-
-        Returns:
-            True if we want to call the model parallel setup hook.
-        """
-        return self.training_type_plugin.call_configure_sharded_model_hook
-
-    @call_configure_sharded_model_hook.setter
-    def call_configure_sharded_model_hook(self, mode: bool) -> None:
-        self.training_type_plugin.call_configure_sharded_model_hook = mode
-
     @property
     def setup_optimizers_in_pre_dispatch(self) -> bool:
         """Override to delay setting optimizers and schedulers till after dispatch. This is useful when the

pytorch_lightning/plugins/training_type/fully_sharded.py

Lines changed: 0 additions & 11 deletions
@@ -141,17 +141,6 @@ def wrap_policy(*args, **kwargs):
         ):
             yield
 
-    def setup_environment(self) -> None:
-        super().setup_environment()
-        model_call_configure_sharded_model_hook = getattr(
-            self.lightning_module, "call_configure_sharded_model_hook", False
-        )
-        if not model_call_configure_sharded_model_hook:
-            # if model has not called configure sharded model, we reset
-            # the training type plugin's call_configure_sharded_model_hook
-            # to give trainer a chance to configure.
-            self.call_configure_sharded_model_hook = True
-
     def configure_ddp(self) -> None:
         if not self.cpu_offload:
             # When using CPU Offload, FSDP will manage the CUDA movement for us.

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 0 additions & 14 deletions
@@ -39,7 +39,6 @@ def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None:
         self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None
         checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO()
         self._checkpoint_io = checkpoint_io
-        self._call_configure_sharded_model_hook = True
 
     @property
     def checkpoint_io(self) -> CheckpointIO:

@@ -281,19 +280,6 @@ def model_sharded_context(self) -> Generator:
         """
         yield
 
-    @property
-    def call_configure_sharded_model_hook(self) -> bool:
-        """Allow model parallel hook to be called in suitable environments determined by the training type plugin.
-
-        This is useful for when we want to shard the model once within fit.
-        Returns: True if we want to call the model parallel setup hook.
-        """
-        return self._call_configure_sharded_model_hook
-
-    @call_configure_sharded_model_hook.setter
-    def call_configure_sharded_model_hook(self, mode: bool) -> None:
-        self._call_configure_sharded_model_hook = mode
-
     @abstractmethod
     def teardown(self) -> None:
         """This method is called to teardown the training process.

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 12 deletions
@@ -1276,18 +1276,9 @@ def _call_setup_hook(self) -> None:
         self.accelerator.barrier("post_setup")
 
     def _call_configure_sharded_model(self) -> None:
-        # Call configure sharded model hook if accelerator requests. In some cases
-        # we will not call the hook; the hook has initialized the sharded model for example.
-
-        # used on the model if the user re-create a trainer with resume_from_checkpoint
-        model = self.lightning_module
-        model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
-        if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook:
-            with self.accelerator.model_sharded_context():
-                self.call_hook("configure_sharded_model")
-                self.call_hook("on_configure_sharded_model")
-            model.call_configure_sharded_model_hook = True
-            self.accelerator.call_configure_sharded_model_hook = False
+        with self.accelerator.model_sharded_context():
+            self.call_hook("configure_sharded_model")
+            self.call_hook("on_configure_sharded_model")
 
     def _call_teardown_hook(self) -> None:
         fn = self.state.fn._setup_fn
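
With the flag gone, _call_configure_sharded_model always enters the plugin's model_sharded_context() and fires the configure_sharded_model hooks on every fit, so a LightningModule should make the hook safe to call more than once. A minimal sketch of that pattern, mirroring the _init_model guard added to TestFSDPModel further down (the MyModel module and its layer sizes are illustrative, not part of this commit):

import torch
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = None  # built lazily so it can be created inside the sharded context

    def configure_sharded_model(self) -> None:
        # may now run on every fit: guard so repeated calls do not rebuild the layers
        if self.layer is None:
            self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)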

tests/accelerators/test_common.py

Lines changed: 0 additions & 55 deletions
@@ -16,7 +16,6 @@
 
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
-from pytorch_lightning.plugins import SingleDevicePlugin
 from tests.accelerators.test_dp import CustomClassificationModelDP
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.datamodules import ClassifDataModule

@@ -77,57 +76,3 @@ def configure_sharded_model(self):
     trainer.fit(model)
 
     assert model.configure_sharded_model_called
-
-
-class DummyModel(BoringModel):
-    def __init__(self):
-        super().__init__()
-        self.configure_sharded_model_called = False
-
-    def configure_sharded_model(self):
-        self.configure_sharded_model_called = True
-
-
-def test_configure_sharded_model_false(tmpdir):
-    """Ensure ``configure_sharded_model`` is not called, when turned off."""
-
-    class CustomPlugin(SingleDevicePlugin):
-        @property
-        def call_configure_sharded_model_hook(self) -> bool:
-            return False
-
-    model = DummyModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_train_batches=2,
-        limit_val_batches=2,
-        max_epochs=1,
-        plugins=CustomPlugin(device=torch.device("cpu")),
-    )
-    trainer.fit(model)
-
-    assert not model.configure_sharded_model_called
-
-
-def test_accelerator_configure_sharded_model_called_once(tmpdir):
-    """Ensure that the configure sharded model hook is called, and set to False after to ensure not called
-    again."""
-
-    model = DummyModel()
-    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1)
-    assert trainer.accelerator.call_configure_sharded_model_hook is True
-    trainer.fit(model)
-    assert trainer.accelerator.call_configure_sharded_model_hook is False
-
-
-def test_configure_sharded_model_called_once(tmpdir):
-    """Ensure ``configure_sharded_model`` is only called once."""
-
-    model = DummyModel()
-    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1)
-    trainer.fit(model)
-
-    assert model.configure_sharded_model_called
-    model.configure_sharded_model_called = False
-
-    assert not model.configure_sharded_model_called

tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py

Lines changed: 13 additions & 16 deletions
@@ -49,27 +49,29 @@ def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir):
 
 
 class TestFSDPModel(BoringModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.layer: Optional[torch.nn.Module] = None
+
+    def _init_model(self) -> None:
+        if self.layer is None:
+            self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
+
     def setup(self, stage: str) -> None:
-        if stage != "fit":
-            # when running stages like test, validate, and predict, we will skip setting up,
-            # will directly use the module itself unless we load from checkpoint
-            return
-        # resetting call_configure_sharded_model_hook attribute so that we could call
-        # configure sharded model
-        self.call_configure_sharded_model_hook = False
-        # for loading full state dict, we first need to create a new unwrapped model
-        # to load state dict and then wrapping
-        self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
+        self._init_model()
 
     def configure_sharded_model(self) -> None:
+        # the model is already wrapped with FSDP: no need to wrap again!
+        if isinstance(self.layer, FullyShardedDataParallel):
+            return
         for i, layer in enumerate(self.layer):
             if i % 2 == 0:
                 self.layer[i] = wrap(layer)
         self.layer = wrap(self.layer)
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # when loading full state dict, we first need to create a new unwrapped model
-        self.setup("fit")
+        self._init_model()
 
     def configure_optimizers(self):
         return torch.optim.SGD(self.layer.parameters(), lr=0.1)

@@ -131,13 +133,8 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
 def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
     trainer.fit(model)
 
-    model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
-    trainer_accelerator_call_configure_sharded_model_hook = trainer.accelerator.call_configure_sharded_model_hook
-
     model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path
 
-    assert model_call_configure_sharded_model_hook
-    assert not trainer_accelerator_call_configure_sharded_model_hook
     trainer.save_checkpoint(model_path, weights_only=True)
 
     _assert_save_equality(trainer, model_path, cls=TestFSDPModel)
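
The rewritten test relies on the hook now being idempotent. A rough sketch of the multi-stage flow it exercises (the trainer arguments and the "fsdp" plugin alias are assumptions for illustration only; the real test additionally uses tmpdir paths, GPU markers, and fairscale):

from pytorch_lightning import Trainer

# Assumes the TestFSDPModel defined above, a CUDA device, and fairscale installed.
model = TestFSDPModel()
trainer = Trainer(gpus=1, plugins="fsdp", precision=16, max_epochs=1)

trainer.fit(model)  # configure_sharded_model wraps the layers with FSDP once
trainer.save_checkpoint("model.pt", weights_only=True)

# A later stage triggers the hook again; the isinstance(..., FullyShardedDataParallel)
# guard prevents double wrapping, and on_load_checkpoint -> _init_model() recreates an
# unwrapped layer before a full state dict is loaded.
trainer.test(model, ckpt_path="model.pt")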
