
Commit 012fb2d

awaelchli and tchaton committed
Consolidate state when retrieving sharded state dict in Lite (#10746)
Co-authored-by: thomas chaton <[email protected]>
1 parent 15bbfac commit 012fb2d
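
Context for the change: in Lite, `setup()` wraps each optimizer in a `_LiteOptimizer`. Under a sharded strategy, each process holds only its shard of the optimizer state, and the strategy is responsible for consolidating the shards before the state can be saved. Before this commit, `state_dict()` on the wrapper fell through to the underlying optimizer and skipped that consolidation. Below is a hypothetical sketch of the fixed user-facing flow, assuming the PyTorch Lightning 1.5-era `LightningLite` API; the model, the `"ddp_sharded"` strategy choice, and the checkpoint path are illustrative, not part of this commit.

# Hypothetical usage sketch (PL 1.5-era API); names are illustrative, see above.
import torch
from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self):
        model = torch.nn.Linear(4, 4)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        # setup() returns the model and the optimizer wrapped in _LiteOptimizer
        model, optimizer = self.setup(model, optimizer)

        # After this commit, state_dict() delegates to the accelerator, which
        # lets the sharded strategy consolidate state across processes instead
        # of returning only the local shard.
        state = optimizer.state_dict()
        if self.global_rank == 0:
            torch.save(state, "optimizer.pt")


# "ddp_sharded" (fairscale) is one strategy that shards optimizer state.
Lite(accelerator="gpu", strategy="ddp_sharded", devices=2).run()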

3 files changed: 16 additions, 1 deletion


CHANGELOG.md (3 additions, 0 deletions)

@@ -29,6 +29,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed signals being registered within threads ([#10610](https://github.com/PyTorchLightning/pytorch-lightning/pull/10610))
 - Fixed an issue with collecting logged test results with multiple dataloaders ([#10522](https://github.com/PyTorchLightning/pytorch-lightning/pull/10522))
 
+- Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746))
+
+
 - Fixed `_compare_version` for python packages ([#10762](https://github.com/PyTorchLightning/pytorch-lightning/pull/10762))

pytorch_lightning/lite/wrappers.py (4 additions, 1 deletion)

@@ -46,7 +46,7 @@ def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None:
         """
         # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
         # not want to call on destruction of the `_LiteOptimizer`
-        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}
+        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
         self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
         self._accelerator = accelerator
@@ -55,6 +55,9 @@ def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None:
     def optimizer(self) -> Optimizer:
         return self._optimizer
 
+    def state_dict(self) -> Dict[str, Tensor]:
+        return self._accelerator.optimizer_state(self.optimizer)
+
     def step(self, closure: Optional[Callable] = None) -> None:
         closure = closure or _do_nothing_closure
         self._accelerator.optimizer_step(
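
There are two parts to this hunk. The added `state_dict()` delegates to `Accelerator.optimizer_state()`, the hook that sharded training type plugins override to consolidate state across ranks before returning it. Filtering `"state_dict"` out of the copied `__dict__` is the defensive half: Python resolves instance attributes before plain functions defined on a class, so an instance-level `state_dict` carried over from the wrapped optimizer would shadow the new method. A self-contained sketch of that shadowing mechanism follows; `Wrapped` and `Wrapper` are illustrative names, not from the commit.

# Standalone sketch of the shadowing that the `__dict__` filter guards against.


class Wrapped:
    def state_dict(self):
        return {"source": "wrapped class"}


class Wrapper:
    def __init__(self, wrapped):
        # Copy the wrapped object's instance attributes, but drop `state_dict`
        # so it cannot shadow the method defined on this class (instance
        # attributes take precedence over plain class functions during lookup).
        self.__dict__ = {k: v for k, v in wrapped.__dict__.items() if k != "state_dict"}

    def state_dict(self):
        return {"source": "wrapper (consolidated)"}


obj = Wrapped()
# Some optimizer implementations patch `state_dict` onto the instance:
obj.state_dict = lambda: {"source": "instance attribute"}

w = Wrapper(obj)
# The filter ensures the wrapper's own method is the one that runs:
assert w.state_dict() == {"source": "wrapper (consolidated)"}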

tests/lite/test_wrappers.py (9 additions, 0 deletions)

@@ -142,6 +142,15 @@ def test_lite_optimizer_wraps():
     assert isinstance(lite_optimizer, optimizer_cls)
 
 
+def test_lite_optimizer_state_dict():
+    """Test that the LiteOptimizer calls into the accelerator/strategy to collect the state."""
+    optimizer = Mock()
+    accelerator = Mock()
+    lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator)
+    lite_optimizer.state_dict()
+    accelerator.optimizer_state.assert_called_with(optimizer)
+
+
 def test_lite_optimizer_steps():
     """Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
     optimizer = Mock()
