Commit 8538c1f

shuyingsunshine21, pre-commit-ci[bot], and SeanNaren authored
Accelerator model state dict (#7474)
* Fix some test errors Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
* checkpoint consolidation
* Update ddp_spawn.py
* Update test_metric_result_integration.py
* Update test_results.py
* Update utils.py
* Update utils.py
* Update test_all_gather_grad.py
* Update test_all_gather_grad.py
* Update test_results.py
* Revert "Update test_results.py" This reverts commit 9d4a2b8.
* Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkpoint_consolidate" This reverts commit c5053da, reversing changes made to 0d23d75.
* Revert "Update test_all_gather_grad.py" This reverts commit 0d23d75.
* Revert "Update utils.py" This reverts commit 70fe5da.
* Revert "Update utils.py" This reverts commit a9aae99.
* Revert "Update test_results.py" This reverts commit ea74906.
* Revert "Update test_metric_result_integration.py" This reverts commit bf70e43.
* Revert "Update ddp_spawn.py" This reverts commit f172101.
* Revert "checkpoint consolidation" This reverts commit 536c132.
* Revert "Revert "checkpoint consolidation"" This reverts commit 3a9fde9.
* Revert "Revert "Revert "checkpoint consolidation""" This reverts commit 7a369f4.
* Revert "Revert "Update ddp_spawn.py"" This reverts commit 8222dc9.
* Revert "Revert "Update test_metric_result_integration.py"" This reverts commit 6c095b2.
* Revert "Revert "Update test_results.py"" This reverts commit 250d0aa.
* Revert "Revert "Update utils.py"" This reverts commit 8651d54.
* Revert "Revert "Update test_all_gather_grad.py"" This reverts commit dcdcd29.
* modify distributed environment to make test pass
* modify model state dict to training type plugin
* remove changes
* add changelog
* fixing isort for pre-commit failure
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Address code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: SeanNaren <[email protected]>
1 parent a1a655d commit 8538c1f

4 files changed: 16 additions and 1 deletion

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
 
 
+- Changed `model.state_dict()` in `CheckpointConnector` to allow `training_type_plugin` to customize the model's `state_dict()` ([7474](https://github.com/PyTorchLightning/pytorch-lightning/pull/7474))
+
+
 ### Deprecated
 
 

pytorch_lightning/accelerators/accelerator.py

Lines changed: 6 additions & 0 deletions
@@ -420,6 +420,12 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """
         return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)
 
+    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
+        """
+        Returns state of model. Allows for syncing/collating model state from processes in custom plugins.
+        """
+        return self.training_type_plugin.lightning_module_state_dict()
+
     def on_save(self, checkpoint: Dict[str, Union[Any, Tensor]]) -> Dict[str, Union[Any, Tensor]]:
         return self.training_type_plugin.on_save(checkpoint)
 

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 6 additions & 0 deletions
@@ -16,6 +16,7 @@
 from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, TypeVar, Union
 
 import torch
+from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -241,6 +242,11 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) ->
         """
         return current_global_step + 1
 
+    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
+        """Returns model state."""
+        model = self.lightning_module
+        return model.state_dict()
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
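
The two hunks above add a default hook on the training type plugin and a forwarding method on the accelerator. A minimal sketch of how a custom plugin might override that hook follows. Every class here (`ToyModule`, `ToyTrainingTypePlugin`, `ToyAccelerator`, `PrefixStrippingPlugin`) is a hypothetical stand-in written for illustration, not a Lightning class, and plain lists stand in for tensors so the example runs without `torch`.

```python
from typing import Any, Dict


class ToyModule:
    """Stand-in for a LightningModule; holds plain lists instead of tensors."""

    def __init__(self) -> None:
        self._weights = {"layer.weight": [1.0, 2.0], "layer.bias": [0.0]}

    def state_dict(self) -> Dict[str, Any]:
        return dict(self._weights)


class ToyTrainingTypePlugin:
    """Mirrors the default hook from this commit: return the module's state dict."""

    def __init__(self, module: ToyModule) -> None:
        self.lightning_module = module

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        model = self.lightning_module
        return model.state_dict()


class ToyAccelerator:
    """Forwards the call to its plugin, as the accelerator method in this commit does."""

    def __init__(self, plugin: ToyTrainingTypePlugin) -> None:
        self.training_type_plugin = plugin

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        return self.training_type_plugin.lightning_module_state_dict()


class PrefixStrippingPlugin(ToyTrainingTypePlugin):
    """Hypothetical override: rename keys before they reach the checkpoint."""

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        state = super().lightning_module_state_dict()
        return {key.replace("layer.", "", 1): value for key, value in state.items()}


default_state = ToyAccelerator(ToyTrainingTypePlugin(ToyModule())).lightning_module_state_dict()
custom_state = ToyAccelerator(PrefixStrippingPlugin(ToyModule())).lightning_module_state_dict()
```

Because the accelerator only forwards the call, swapping the plugin is enough to change what ends up in the state dict; the caller does not need to know which plugin is active.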

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 1 addition & 1 deletion
@@ -273,7 +273,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             'epoch': current_epoch,
             'global_step': global_step,
             'pytorch-lightning_version': pytorch_lightning.__version__,
-            'state_dict': model.state_dict(),
+            'state_dict': self.trainer.accelerator.lightning_module_state_dict(),
         }
 
         if not weights_only:
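
The one-line change above routes the checkpoint's `state_dict` entry through the accelerator instead of calling `model.state_dict()` directly. The sketch below illustrates why that matters for a plugin that must collate state from several processes. It is an assumption-laden toy: `ShardedPlugin`, `StubAccelerator`, and this free-standing `dump_checkpoint` function are hypothetical, and a list of dicts stands in for state that would really be gathered through collective communication.

```python
from typing import Any, Dict, List


class ShardedPlugin:
    """Hypothetical plugin whose model state is split into per-process shards."""

    def __init__(self, shards: List[Dict[str, Any]]) -> None:
        self.shards = shards

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        # Collate the shards into one full state dict.
        full_state: Dict[str, Any] = {}
        for shard in self.shards:
            full_state.update(shard)
        return full_state


class StubAccelerator:
    """Hypothetical accelerator that forwards the hook to its plugin."""

    def __init__(self, plugin: ShardedPlugin) -> None:
        self.training_type_plugin = plugin

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        return self.training_type_plugin.lightning_module_state_dict()


def dump_checkpoint(accelerator: StubAccelerator, current_epoch: int, global_step: int) -> Dict[str, Any]:
    """Simplified stand-in for the checkpoint dict built in the diff above."""
    return {
        "epoch": current_epoch,
        "global_step": global_step,
        "state_dict": accelerator.lightning_module_state_dict(),
    }


plugin = ShardedPlugin([{"encoder.weight": [0.1]}, {"decoder.weight": [0.2]}])
checkpoint = dump_checkpoint(StubAccelerator(plugin), current_epoch=3, global_step=120)
```

With a direct `model.state_dict()` call, the checkpoint could only contain whatever the local module holds; going through the hook lets the plugin decide what the full state is.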
