
Add support for init_meta_context, materialize_module #9920

Merged: 42 commits, merged on Oct 21, 2021. Changes shown are from 40 of the 42 commits.

Commits (42)
9a8954e
update
tchaton Oct 13, 2021
36bb238
update
tchaton Oct 13, 2021
f1890bc
remove credit
tchaton Oct 13, 2021
103c311
update
tchaton Oct 14, 2021
f346120
update
tchaton Oct 14, 2021
8f7fc11
update
tchaton Oct 14, 2021
3d7852f
add changelog
tchaton Oct 14, 2021
feb6c9c
update
tchaton Oct 14, 2021
0cdbec2
update on comments
tchaton Oct 14, 2021
8c0402b
update changelog
tchaton Oct 14, 2021
ad0f3ba
typo
tchaton Oct 14, 2021
73c0588
update
tchaton Oct 14, 2021
ff41479
update
tchaton Oct 14, 2021
402e6f6
update
tchaton Oct 15, 2021
e0d4c5b
update
tchaton Oct 15, 2021
57f4ec0
update changelog
tchaton Oct 15, 2021
1b5fb68
update
tchaton Oct 15, 2021
11a3eb9
add note
tchaton Oct 15, 2021
e116e78
update
tchaton Oct 15, 2021
0f8fb06
Merge branch 'set_meta_device' of https://github.com/PyTorchLightning…
tchaton Oct 15, 2021
ee15d11
update test name
tchaton Oct 15, 2021
f8d2e9e
wip
tchaton Oct 15, 2021
0318480
update
tchaton Oct 15, 2021
78744bc
add some typing
tchaton Oct 15, 2021
0bd6b72
update on comments
tchaton Oct 15, 2021
92b5a63
resolve bug
tchaton Oct 15, 2021
7661b1b
add layernorm
tchaton Oct 15, 2021
f78db68
update
tchaton Oct 15, 2021
5eeec6a
revert back
tchaton Oct 15, 2021
a03cd69
replace the in_place
tchaton Oct 15, 2021
f28673c
remove extra lines
tchaton Oct 15, 2021
43b62ee
update
tchaton Oct 15, 2021
0595843
remove list
tchaton Oct 15, 2021
8b27b15
update
tchaton Oct 15, 2021
0850f1e
update
tchaton Oct 15, 2021
e3f991b
update
tchaton Oct 16, 2021
cfb42a2
add a warning about unstability
tchaton Oct 16, 2021
50357b2
add a warning about unstability
tchaton Oct 16, 2021
df531aa
update test
tchaton Oct 16, 2021
50e9d65
Merge branch 'master' into set_meta_device
tchaton Oct 19, 2021
0afb695
revert on previous api based on can comments
tchaton Oct 20, 2021
2d8c0a1
Merge branch 'master' into set_meta_device
tchaton Oct 20, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -196,6 +196,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))


- Added `init_meta_context`, `materialize_module` utilities ([#9920](https://github.com/PyTorchLightning/pytorch-lightning/pull/9920))


- Added `TPUPrecisionPlugin` ([#10020](https://github.com/PyTorchLightning/pytorch-lightning/pull/#10020))


@@ -208,6 +211,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
* Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))


### Changed

- Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)).
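For reference, a minimal sketch of how the two new utilities added by this PR are meant to be combined; the layer sizes and the plain `nn.Sequential` model are illustrative, and the exact signatures may differ from the final API:

```python
import torch.nn as nn

from pytorch_lightning.utilities.meta import init_meta_context, materialize_module

# Modules created inside the context get their parameters on the "meta" device:
# shapes and dtypes are tracked, but no parameter memory is allocated yet.
with init_meta_context():
    model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 10))

# Recursively replace the meta tensors with real, initialized parameters
# once they are actually needed (e.g. right before sharding or training).
materialize_module(model)
```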
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -380,7 +380,7 @@ def pre_dispatch(self):
def init_deepspeed(self):
# check that `configure_gradient_clipping` hook isn't overriden since deepspeed handles
# gradient clipping internally
if is_overridden("configure_gradient_clipping", self.lightning_module):
if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
rank_zero_warn(
"Since deepspeed handles gradient clipping internally, this hook will"
" be ignored. Consider setting `gradient_clip_val` and `gradient_clip_algorithm`"
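The change above passes the parent class explicitly, so the check compares the user's hook against `pl.LightningModule` instead of inferring the parent. A minimal sketch of the idea behind `is_overridden`, under an assumed simplified signature (not the actual Lightning implementation):

```python
import pytorch_lightning as pl

def is_overridden_sketch(method_name: str, instance: object, parent: type = pl.LightningModule) -> bool:
    # A hook counts as overridden when the instance's class supplies a
    # different function object for `method_name` than the parent class does.
    if instance is None:
        return False
    instance_attr = getattr(type(instance), method_name, None)
    parent_attr = getattr(parent, method_name, None)
    if instance_attr is None or parent_attr is None:
        return False
    return instance_attr is not parent_attr
```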
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -89,6 +89,7 @@
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.meta import materialize_module
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import (
@@ -1349,6 +1350,7 @@ def _call_setup_hook(self) -> None:

def _call_configure_sharded_model(self) -> None:
with self.accelerator.model_sharded_context():
materialize_module(self.lightning_module)
self.call_hook("configure_sharded_model")
self.call_hook("on_configure_sharded_model")

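The added `materialize_module(self.lightning_module)` call runs inside `model_sharded_context()` just before the `configure_sharded_model` hook, so a module built on the meta device is turned into real tensors by the time that hook fires. A rough sketch of the intended user-side flow; `MyLightningModule` and the bare `Trainer()` arguments are placeholders, not code from this PR:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.meta import init_meta_context

# Define the (potentially huge) model without allocating parameter memory.
with init_meta_context():
    model = MyLightningModule()

trainer = Trainer()
# During setup, _call_configure_sharded_model() now materializes the module
# inside the sharded-model context before calling configure_sharded_model.
trainer.fit(model)
```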
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/imports.py
@@ -100,6 +100,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
_TORCH_BFLOAT_AVAILABLE = _compare_version(
"torch", operator.ge, "1.10.0.dev20210902"
) # todo: swap to 1.10.0 once released
_TORCH_META_AVAILABLE = _compare_version("torch", operator.ge, "1.10.0.dev20210922")
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
_TORCH_SHARDED_TENSOR_AVAILABLE = _compare_version(
"torch", operator.ge, "1.10.0.dev20210809"
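For context, the new `_TORCH_META_AVAILABLE` flag gates the meta utilities on a new-enough torch nightly. A hedged sketch of how a `_compare_version`-style check can be written with `packaging` (an illustration, not the actual Lightning helper):

```python
import operator
from importlib.metadata import version
from packaging.version import Version

def compare_version_sketch(package: str, op, target: str) -> bool:
    # Evaluate e.g. Version(installed torch) >= Version("1.10.0.dev20210922").
    try:
        installed = Version(version(package))
    except Exception:
        return False
    return op(installed, Version(target))

TORCH_META_AVAILABLE = compare_version_sketch("torch", operator.ge, "1.10.0.dev20210922")
```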