diff --git a/CHANGELOG.md b/CHANGELOG.md
index efa420bebcd8c..58a8ab471d719 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -195,6 +195,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))
 
+- Added `init_meta_context`, `materialize_module` utilities ([#9920](https://github.com/PyTorchLightning/pytorch-lightning/pull/9920))
+
+
 - Added `TPUPrecisionPlugin` ([#10020](https://github.com/PyTorchLightning/pytorch-lightning/pull/10020))
 
@@ -214,6 +217,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
   * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
 
+
 ### Changed
 
 - Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)).
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index aa0ca45cabc0e..883ae74df1346 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -426,7 +426,7 @@ def _setup_model_and_optimizer(
     def init_deepspeed(self):
         # check that the `configure_gradient_clipping` hook isn't overridden, since deepspeed handles
         # gradient clipping internally
-        if is_overridden("configure_gradient_clipping", self.lightning_module):
+        if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
             rank_zero_warn(
                 "Since deepspeed handles gradient clipping internally, this hook will"
                 " be ignored. Consider setting `gradient_clip_val` and `gradient_clip_algorithm`"
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index ce53b3cea3072..7b30d7b1c4a48 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -89,6 +89,7 @@
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
+from pytorch_lightning.utilities.meta import materialize_module
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import (
@@ -1348,6 +1349,7 @@ def _call_setup_hook(self) -> None:
 
     def _call_configure_sharded_model(self) -> None:
         with self.accelerator.model_sharded_context():
+            materialize_module(self.lightning_module)
             self.call_hook("configure_sharded_model")
             self.call_hook("on_configure_sharded_model")
 
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index b4e3d043bb8e1..69cf3ce1d4a9f 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -93,6 +93,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
 _OMEGACONF_AVAILABLE = _module_available("omegaconf")
 _POPTORCH_AVAILABLE = _module_available("poptorch")
 _RICH_AVAILABLE = _module_available("rich") and _compare_version("rich", operator.ge, "10.2.2")
+_TORCH_META_AVAILABLE = _compare_version("torch", operator.ge, "1.10.0.dev20210922")
 _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
 _TORCHTEXT_AVAILABLE = _module_available("torchtext")
 _TORCHVISION_AVAILABLE = _module_available("torchvision")
diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py
new file mode 100644
index 0000000000000..11c903a5abd65
--- /dev/null
+++ b/pytorch_lightning/utilities/meta.py
@@ -0,0 +1,323 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import inspect
+import threading
+from contextlib import contextmanager
+from functools import partial
+from itertools import chain
+from types import ModuleType
+from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type
+
+import torch
+from torch import nn, Tensor
+from torch.nn import Module
+from torch.nn.modules.container import ModuleDict, ModuleList, Sequential
+
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _TORCH_META_AVAILABLE
+
+if _TORCH_META_AVAILABLE:
+    from torch._C import _DisableTorchDispatch  # type: ignore[attr-defined]
+
+    ####################################################################
+    # BELOW: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
+    # TODO: Remove once merged and released on the PyTorch side        #
+    ####################################################################
+
+    @contextmanager
+    def enable_python_mode(cls) -> Iterator[None]:
+        if not hasattr(cls, "__torch_dispatch__"):
+            raise ValueError("The class passed to enable_python_mode must have a __torch_dispatch__ classmethod")
+        if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
+            raise ValueError("The argument passed to enable_python_mode must be the type of a Tensor subclass")
+        torch._C._enter_python_mode(cls)
+        try:
+            yield
+        finally:
+            torch._C._exit_python_mode()
+
+    _tls = threading.local()
+    _tls.in_call = False
+
+    @contextmanager
+    def _no_dispatch() -> Iterator[None]:
+        """Temporarily disables the Python dispatch mode."""
+        guard = _DisableTorchDispatch()  # noqa: F841
+        try:
+            yield
+        finally:
+            del guard
+
+    def _handle_arange(func, args, kwargs):
+        # run `arange` on the CPU, then mirror only its shape/dtype onto the meta device
+        kwargs["device"] = torch.device("cpu")
+        return torch.empty_like(func(*args, **kwargs), device="meta")
+
+    def _handle_tril(func, args, kwargs):
+        if args and isinstance(args[0], Tensor):
+            return torch.empty_like(args[0], device="meta")
+
+        return NotImplemented
+
+    class _MetaContext(Tensor):
+        _op_handlers: Dict[Callable, Callable] = {}
+
+        @classmethod
+        def _ensure_handlers_initialized(cls) -> None:
+            if cls._op_handlers:
+                return
+
+            # ops that cannot run natively on the meta device get a dedicated handler
+            cls._op_handlers.update(
+                {
+                    torch.ops.aten.arange: _handle_arange,
+                    torch.ops.aten.tril: _handle_tril,
+                }
+            )
+
+        @classmethod
+        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+            cls._ensure_handlers_initialized()
+
+            # the dispatcher may pass `kwargs` as `None`
+            kwargs = kwargs or {}
+
+            op_handler: Optional[Callable]
+
+            try:
+                op_handler = cls._op_handlers[func]
+            except KeyError:
+                op_handler = None
+
+            with _no_dispatch():
+                if op_handler:
+                    result = op_handler(func, args, kwargs)
+                    if result is not NotImplemented:
+                        return result
+
+                if "device" in kwargs:
+                    kwargs["device"] = torch.device("meta")
+
+                return func(*args, **kwargs)
+
+    def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module:
+        def create_instance(module=None) -> Module:
+            if module:
+                # re-run the constructor on an existing instance to materialize it
+                module.__init__(*args, **kwargs)
+                return module
+            return module_fn(*args, **kwargs)
+
+        if _tls.in_call:
+            module = create_instance()
+        else:
+            _tls.in_call = True
+            try:
+                with enable_python_mode(_MetaContext):
+                    module = create_instance()
+            finally:
+                _tls.in_call = False
+
+        module.materialize = partial(create_instance, module=module)  # type: ignore[assignment]
+
+        return module
+
+    def is_meta_init() -> bool:
+        """Indicates whether the module is being instantiated by ``init_meta()``."""
+        return _tls.in_call
+
+
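+    # Usage sketch (illustrative, not upstream code): `init_meta` runs the constructor
+    # under `_MetaContext`, so parameters land on the meta device without allocating
+    # storage, and attaches a `materialize` closure that re-runs `__init__` for real:
+    #
+    #     module = init_meta(nn.Linear, in_features=1, out_features=1)
+    #     assert module.weight.device.type == "meta"
+    #     module = module.materialize()  # parameters are now allocated on the CPU
+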
+    ####################################################################
+    # ABOVE: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
+    # TODO: Remove once merged and released on the PyTorch side        #
+    ####################################################################
+
+else:
+
+    def init_meta(*_, **__):
+        raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")
+
+
+# https://stackoverflow.com/a/63851681/9201239
+def get_all_subclasses(cls: Type[nn.Module]) -> Set[Type[nn.Module]]:
+    subclass_list = []
+
+    def recurse(cl):
+        for subclass in cl.__subclasses__():
+            subclass_list.append(subclass)
+            recurse(subclass)
+
+    recurse(cls)
+
+    return set(subclass_list)
+
+
+def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module: nn.Module) -> None:
+    *path, name = prefix.split(".")
+    for p in path:
+        root_module = getattr(root_module, p)
+
+    try:
+        index = int(name)
+        root_module[index] = materialized_module
+    except ValueError:
+        setattr(root_module, name, materialized_module)
+
+
+def materialize_module(root_module: nn.Module) -> nn.Module:
+    """Materialize a module and all of its children in-place."""
+    if not _TORCH_META_AVAILABLE:
+        return root_module
+
+    materialize_fn = getattr(root_module, "materialize", None)
+    if materialize_fn and not isinstance(root_module, (Sequential, ModuleList, ModuleDict)):
+        return materialize_fn()
+
+    for name, child in root_module.named_children():
+        materialize_fn = getattr(child, "materialize", None)
+        if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)):
+            materialize_module(child)
+        else:
+            # replace the child on its parent, not on the child itself
+            setattr(root_module, name, materialize_fn())
+    return root_module
+
+
+# cache subclasses to optimize the search when resetting the meta device later on.
+__STORAGE_META__ = {}
+
+__CREATED_MODULES__ = set()
+
+
+def _unset_meta_device(from_created: bool = False) -> None:
+    """Replace all meta modules with their original class."""
+    if not _TORCH_META_AVAILABLE:
+        raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")
+
+    if from_created:
+        values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
+    else:
+        values = __STORAGE_META__.values()
+
+    for mods, subclass, _ in values:
+        for mod in mods:
+            setattr(mod, subclass.__name__, subclass)
+
+
+def _set_meta_device_populated(from_created: bool = False) -> None:
+    """Replace all original modules with their meta class."""
+    if not _TORCH_META_AVAILABLE:
+        raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")
+
+    if from_created:
+        values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
+    else:
+        values = __STORAGE_META__.values()
+
+    for mods, subclass, meta_class in values:
+        for mod in mods:
+            setattr(mod, subclass.__name__, meta_class)
+
+
+def _set_meta_device() -> None:
+    """Replace all ``torch.nn.Module`` subclasses with their meta replacement."""
+
+    if not _TORCH_META_AVAILABLE:
+        raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")
+
+    # Author note: This can be optimized further by searching all subclasses at once.
+    # Its time complexity is O(n * m), where n is the number of subclasses (assuming no
+    # multiple inheritance) and m is the number of modules each subclass can be imported from.
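+    # Illustration (assumption, mirroring the tests below): after `_set_meta_device()`,
+    # `torch.nn.Linear` resolves to a generated `_MetaClass` whose `__new__` routes the
+    # construction through `init_meta`, so `nn.Linear(1, 1).weight.device.type == "meta"`.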
+
+    for subclass in get_all_subclasses(torch.nn.modules.module.Module):
+
+        if issubclass(subclass, (Sequential, ModuleList, ModuleDict)):
+            continue
+
+        # if a subclass has already been stored, we should use the cache
+        if subclass in __STORAGE_META__:
+            # reset the class at its import locations to its rightful state.
+            mods, subclass, meta_class = __STORAGE_META__[subclass]
+            for mod in mods:
+                setattr(mod, subclass.__name__, meta_class)
+            continue
+
+        # Create a class subclassing the current `subclass`, overriding its `__new__` method.
+        # This enables us to use `init_meta` to create a `meta` version of the current
+        # subclass module.
+        class _MetaClass(subclass):
+            @classmethod
+            @contextmanager
+            def instantiation_context(cls, materialize: bool):
+                _unset_meta_device(from_created=True)
+                yield
+                _set_meta_device_populated(from_created=True)
+
+            @classmethod
+            def materialize(cls, materialize_fn: Callable):
+                with cls.instantiation_context(materialize=True):
+                    obj = materialize_fn()
+                return obj
+
+            @staticmethod
+            def add_subclasses(subclass):
+                """This is used to unroll the instantiation tree while creating the modules."""
+                __CREATED_MODULES__.add(subclass)
+                if subclass.__bases__[0] != torch.nn.modules.module.Module:
+                    _MetaClass.add_subclasses(subclass.__bases__[0])
+
+            def __new__(cls, *args, **kwargs):
+                subclass = cls.__bases__[0]
+                cls.add_subclasses(subclass)
+                with cls.instantiation_context(materialize=False):
+                    obj = init_meta(subclass, *args, **kwargs)
+
+                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
+                return obj
+
+        def search(mod: ModuleType) -> List[ModuleType]:
+            out = []
+            for _, obj in inspect.getmembers(mod):
+                if obj == subclass:
+                    out.append(mod)
+            return out
+
+        submodules = subclass.__module__.split(".")
+        mod = importlib.import_module(submodules[0])
+
+        # The `nn.Module` class can be imported at different levels and all of them need to be mocked.
+        # Example: `torch.nn.Linear` is actually `torch.nn.modules.linear.Linear`.
+        # Therefore, `torch.nn.Linear`, `torch.nn.modules.Linear` and `torch.nn.modules.linear.Linear`
+        # all need to be replaced by `_MetaClass`.
+        out = []
+        out.append(search(mod))
+        for name in submodules[1:]:
+            mod = getattr(mod, name)
+            out.append(search(mod))
+
+        # drop empty modules
+        mods = [mod for mod in chain(*out) if mod]
+
+        # store the module search results so they don't have to be computed again for this class
+        __STORAGE_META__[subclass] = (mods, subclass, _MetaClass)
+
+        # replace the subclass with its meta form at every import location
+        for mod in mods:
+            setattr(mod, subclass.__name__, _MetaClass)
+
+
+@contextmanager
+def init_meta_context() -> Generator:
+    rank_zero_warn(
+        "Be aware this feature is highly experimental and there are a number of weird edge cases "
+        "where it can internally assert and/or crash. A more stable version is expected in PyTorch 1.11."
+    )
+    _set_meta_device()
+    try:
+        yield
+    finally:
+        _unset_meta_device()
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index fad764b896ff3..7c0623323f6f1 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -17,7 +17,8 @@
 from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
 from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
+from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE, _TORCH_META_AVAILABLE
+from pytorch_lightning.utilities.meta import init_meta_context
 from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
@@ -1044,3 +1045,16 @@ def on_test_batch_start(
     )
     trainer.fit(model)
     trainer.test(model)
+
+
+@pytest.mark.skipif(not _TORCH_META_AVAILABLE, reason="the meta device context is supported from PyTorch 1.10.")
+@RunIf(min_gpus=2, deepspeed=True, special=True)
+def test_deepspeed_with_meta_device(tmpdir):
+    with init_meta_context():
+        model = BoringModel()
+    assert model.layer.weight.device.type == "meta"
+    trainer = Trainer(
+        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
+    )
+    trainer.fit(model)
+    assert model.layer.weight.device.type == "cpu"
diff --git a/tests/utilities/test_meta.py b/tests/utilities/test_meta.py
new file mode 100644
index 0000000000000..efcca45a7483e
--- /dev/null
+++ b/tests/utilities/test_meta.py
@@ -0,0 +1,66 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+from torch import nn
+
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities.imports import _TORCH_META_AVAILABLE
+from pytorch_lightning.utilities.meta import init_meta_context, materialize_module
+
+
+class MLP(nn.Module):
+    def __init__(self, num_layers: int):
+        super().__init__()
+        self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(num_layers)] + [nn.Dropout(), nn.LayerNorm(1)])
+
+
+class BoringModel(LightningModule):
+    def __init__(self, num_layers: int):
+        super().__init__()
+        self.save_hyperparameters()
+        self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)])
+
+
+@pytest.mark.skipif(not _TORCH_META_AVAILABLE, reason="Supported only with PyTorch 1.10")
+def test_init_meta_context():
+
+    with init_meta_context():
+        m = nn.Linear(in_features=1, out_features=1)
+        assert m.weight.device.type == "meta"
+        mlp = MLP(4)
+        assert mlp.layer[0].weight.device.type == "meta"
+
+        mlp = materialize_module(mlp)
+        assert mlp.layer[0].weight.device.type == "cpu"
+
+        model = BoringModel(4)
+        assert model.layer[0].weight.device.type == "meta"
+        materialize_module(model)
+        assert model.layer[0].weight.device.type == "cpu"
+
+    mlp = MLP(4)
+    assert mlp.layer[0].weight.device.type == "cpu"
+    # no-op, as the module is already materialized.
+    materialize_module(mlp)
+    assert mlp.layer[0].weight.device.type == "cpu"
+
+    m = nn.Linear(in_features=1, out_features=1)
+    assert m.weight.device.type == "cpu"
+
+    with init_meta_context():
+        m = nn.Linear(in_features=1, out_features=1)
+        assert m.weight.device.type == "meta"
+
+    m = nn.Linear(in_features=1, out_features=1)
+    assert m.weight.device.type == "cpu"
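
A minimal end-to-end sketch of how the utilities added in this diff compose, assuming a hypothetical user model `MyModel` (not part of the patch). Under PyTorch < 1.10, `init_meta_context` raises a `MisconfigurationException` and `materialize_module` is a no-op.

    import torch.nn as nn

    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.meta import init_meta_context, materialize_module


    class MyModel(LightningModule):  # hypothetical stand-in for a user model
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)


    with init_meta_context():
        # construction happens on the meta device: no parameter storage is allocated
        model = MyModel()
    assert model.layer.weight.device.type == "meta"

    # materialize explicitly on the CPU (re-runs `__init__` with the original classes)
    materialize_module(model)
    assert model.layer.weight.device.type == "cpu"

Alternatively, the model can be handed to the Trainer while still on the meta device: `_call_configure_sharded_model` now calls `materialize_module` inside the plugin's sharded context, which is what `test_deepspeed_with_meta_device` above relies on for DeepSpeed stage 3.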