-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Add support for init_meta_context, materialize_module #9920
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
42 commits
Select commit
Hold shift + click to select a range
9a8954e
update
tchaton 36bb238
update
tchaton f1890bc
remove credit
tchaton 103c311
update
tchaton f346120
update
tchaton 8f7fc11
update
tchaton 3d7852f
add changelog
tchaton feb6c9c
update
tchaton 0cdbec2
update on comments
tchaton 8c0402b
update changelog
tchaton ad0f3ba
typo
tchaton 73c0588
update
tchaton ff41479
update
tchaton 402e6f6
update
tchaton e0d4c5b
update
tchaton 57f4ec0
update changelog
tchaton 1b5fb68
update
tchaton 11a3eb9
add note
tchaton e116e78
update
tchaton 0f8fb06
Merge branch 'set_meta_device' of https://github.com/PyTorchLightning…
tchaton ee15d11
update test name
tchaton f8d2e9e
wip
tchaton 0318480
update
tchaton 78744bc
add some typing
tchaton 0bd6b72
update on comments
tchaton 92b5a63
resolve bug
tchaton 7661b1b
add layernorm
tchaton f78db68
update
tchaton 5eeec6a
revert back
tchaton a03cd69
replace the in_place
tchaton f28673c
remove extra lines
tchaton 43b62ee
update
tchaton 0595843
remove list
tchaton 8b27b15
update
tchaton 0850f1e
update
tchaton e3f991b
update
tchaton cfb42a2
add a warning about unstability
tchaton 50357b2
add a warning about unstability
tchaton df531aa
update test
tchaton 50e9d65
Merge branch 'master' into set_meta_device
tchaton 0afb695
revert on previous api based on can comments
tchaton 2d8c0a1
Merge branch 'master' into set_meta_device
tchaton File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,323 @@ | ||
# Copyright The PyTorch Lightning team. | ||
tchaton marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import importlib | ||
import inspect | ||
import threading | ||
from contextlib import contextmanager | ||
from functools import partial | ||
from itertools import chain | ||
from types import ModuleType | ||
from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type | ||
|
||
import torch | ||
from torch import nn, Tensor | ||
from torch.nn import Module | ||
from torch.nn.modules.container import ModuleDict, ModuleList, Sequential | ||
|
||
from pytorch_lightning.utilities import rank_zero_warn | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
from pytorch_lightning.utilities.imports import _TORCH_META_AVAILABLE | ||
|
||
if _TORCH_META_AVAILABLE: | ||
from torch._C import _DisableTorchDispatch # type: ignore[attr-defined] | ||
|
||
#################################################################### | ||
# BELOW: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. # | ||
# TODO: Removed once merged and released on PyTorch side # | ||
#################################################################### | ||
|
||
@contextmanager | ||
def enable_python_mode(cls) -> Iterator[None]: | ||
if not hasattr(cls, "__torch_dispatch__"): | ||
raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod") | ||
if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)): | ||
raise ValueError("The argument passed to enable_python_mode " "must be the type of a Tensor subclass") | ||
torch._C._enter_python_mode(cls) | ||
try: | ||
yield | ||
finally: | ||
torch._C._exit_python_mode() | ||
tchaton marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
_tls = threading.local() | ||
_tls.in_call = False | ||
|
||
@contextmanager | ||
def _no_dispatch() -> Iterator[None]: | ||
"""Temporarily disables the Python dispatch mode.""" | ||
guard = _DisableTorchDispatch() # noqa F841 | ||
try: | ||
yield | ||
finally: | ||
del guard | ||
|
||
def _handle_arange(func, args, kwargs): | ||
kwargs["device"] = torch.device("cpu") | ||
return torch.empty_like(func(*args, **kwargs), device="meta") | ||
|
||
def _handle_tril(func, args, kwargs): | ||
if args and isinstance(args[0], Tensor): | ||
return torch.empty_like(args[0], device="meta") | ||
|
||
return NotImplemented | ||
|
||
class _MetaContext(Tensor): | ||
_op_handlers: Dict[Callable, Callable] = {} | ||
|
||
@classmethod | ||
def _ensure_handlers_initialized(cls) -> None: | ||
if cls._op_handlers: | ||
return | ||
|
||
cls._op_handlers.update( | ||
{ | ||
torch.ops.aten.arange: _handle_arange, | ||
torch.ops.aten.tril: _handle_tril, | ||
} | ||
) | ||
|
||
@classmethod | ||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): | ||
cls._ensure_handlers_initialized() | ||
|
||
op_handler: Optional[Callable] | ||
|
||
try: | ||
op_handler = cls._op_handlers[func] | ||
except KeyError: | ||
op_handler = None | ||
|
||
with _no_dispatch(): | ||
if op_handler: | ||
result = op_handler(func, args, kwargs) | ||
if result is not NotImplemented: | ||
return result | ||
|
||
if "device" in kwargs: | ||
kwargs["device"] = torch.device("meta") | ||
|
||
return func(*args, **kwargs) | ||
|
||
def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module: | ||
def create_instance(module=None) -> Module: | ||
if module: | ||
module.__init__(*args, **kwargs) | ||
return module | ||
return module_fn(*args, **kwargs) | ||
|
||
if _tls.in_call: | ||
module = create_instance() | ||
else: | ||
_tls.in_call = True | ||
try: | ||
with enable_python_mode(_MetaContext): | ||
module = create_instance() | ||
finally: | ||
_tls.in_call = False | ||
|
||
module.materialize = partial(create_instance, module=module) # type: ignore[assignment] | ||
|
||
return module | ||
|
||
def is_meta_init() -> bool: | ||
"""Indicates whether the module is being instantiated by ``init_meta()``.""" | ||
return _tls.in_call | ||
|
||
#################################################################### | ||
# ABOVE: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. # | ||
# TODO: Removed once merged and released on PyTorch side # | ||
#################################################################### | ||
|
||
else: | ||
|
||
def init_meta(*_, **__): | ||
if not _TORCH_META_AVAILABLE: | ||
return MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0") | ||
|
||
|
||
# https://stackoverflow.com/a/63851681/9201239 | ||
def get_all_subclasses(cls: Type[nn.Module]) -> Set[nn.Module]: | ||
subclass_list = [] | ||
|
||
def recurse(cl): | ||
for subclass in cl.__subclasses__(): | ||
subclass_list.append(subclass) | ||
recurse(subclass) | ||
|
||
recurse(cls) | ||
|
||
return set(subclass_list) | ||
|
||
|
||
def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module: nn.Module) -> None: | ||
*path, name = prefix.split(".") | ||
for p in path: | ||
root_module = getattr(root_module, p) | ||
|
||
try: | ||
index = int(name) | ||
root_module[index] = materialized_module | ||
except ValueError: | ||
setattr(root_module, name, materialized_module) | ||
|
||
|
||
def materialize_module(root_module: nn.Module) -> nn.Module: | ||
"""This utility performs an in-place operation by materialize a module and its children.""" | ||
if not _TORCH_META_AVAILABLE: | ||
return root_module | ||
tchaton marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
materialize_fn = getattr(root_module, "materialize", None) | ||
if materialize_fn and not isinstance(root_module, (Sequential, ModuleList, ModuleDict)): | ||
return materialize_fn() | ||
|
||
for name, child in root_module.named_children(): | ||
materialize_fn = getattr(child, "materialize", None) | ||
if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)): | ||
materialize_module(child) | ||
else: | ||
setattr(child, name, materialize_fn()) | ||
return root_module | ||
|
||
|
||
# cache subclasses to optimize the search when resetting the meta device later on. | ||
__STORAGE_META__ = {} | ||
|
||
__CREATED_MODULES__ = set() | ||
|
||
|
||
def _unset_meta_device(from_created: bool = False) -> None: | ||
"""Replace all meta module by their original version.""" | ||
if not _TORCH_META_AVAILABLE: | ||
raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0") | ||
|
||
if from_created: | ||
values = [__STORAGE_META__[key] for key in __CREATED_MODULES__] | ||
else: | ||
values = __STORAGE_META__.values() | ||
|
||
for mods, subclass, _ in values: | ||
for mod in mods: | ||
setattr(mod, subclass.__name__, subclass) | ||
|
||
|
||
def _set_meta_device_populated(from_created: bool = False) -> None: | ||
"""Replace all meta module by their original version.""" | ||
if not _TORCH_META_AVAILABLE: | ||
raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0") | ||
|
||
if from_created: | ||
values = [__STORAGE_META__[key] for key in __CREATED_MODULES__] | ||
else: | ||
values = __STORAGE_META__.values() | ||
|
||
for mods, subclass, meta_class in values: | ||
for mod in mods: | ||
setattr(mod, subclass.__name__, meta_class) | ||
|
||
|
||
def _set_meta_device() -> None: | ||
"""Replace all torch.nn.Module by their meta replacement.""" | ||
|
||
if not _TORCH_META_AVAILABLE: | ||
raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0") | ||
|
||
# Author note: This can be optimized further by searching all subclasses at once. | ||
# Its time complexity is O(n*m) where n is the number of all subclasses if there's no multiple inheritance | ||
# and m the number of all subclasses belonging to its subclass module. | ||
|
||
for subclass in get_all_subclasses(torch.nn.modules.module.Module): | ||
|
||
if isinstance(subclass, (Sequential, ModuleList, ModuleDict)): | ||
continue | ||
|
||
# if a subclass has already been stored, we should use the cache | ||
if str(subclass) in __STORAGE_META__: | ||
# reset the class import package to its rightfull state. | ||
mods, subclass, meta_class = __STORAGE_META__[subclass] | ||
for mod in mods: | ||
setattr(mod, subclass.__name__, meta_class) | ||
continue | ||
|
||
# Create a class subclassing current `subclass` overriding its new method. | ||
# this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta` | ||
# version of the current subclass module | ||
class _MetaClass(subclass): | ||
@classmethod | ||
@contextmanager | ||
def instantiation_context(cls, materialize: bool): | ||
_unset_meta_device(from_created=True) | ||
yield | ||
_set_meta_device_populated(from_created=True) | ||
|
||
@classmethod | ||
def materialize(cls, materialize_fn: Callable): | ||
with cls.instantiation_context(materialize=True): | ||
obj = materialize_fn() | ||
return obj | ||
|
||
@staticmethod | ||
def add_subclasses(subclass): | ||
"""This is used to unrol the instantion tree while creating the modules.""" | ||
__CREATED_MODULES__.add(subclass) | ||
if subclass.__bases__[0] != torch.nn.modules.module.Module: | ||
_MetaClass.add_subclasses(subclass.__bases__[0]) | ||
|
||
def __new__(cls, *args, **kwargs): | ||
subclass = cls.__bases__[0] | ||
cls.add_subclasses(subclass) | ||
with cls.instantiation_context(materialize=False): | ||
obj = init_meta(subclass, *args, **kwargs) | ||
|
||
obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize) | ||
return obj | ||
|
||
def search(mod: ModuleType) -> List[ModuleType]: | ||
out = [] | ||
for _, obj in inspect.getmembers(mod): | ||
if obj == subclass: | ||
out.append(mod) | ||
return out | ||
|
||
submodules = subclass.__module__.split(".") | ||
mod = importlib.import_module(submodules[0]) | ||
|
||
# nn.Module class can be imported at different level and they all need to be mocked. | ||
# Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear | ||
# Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear | ||
# needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass | ||
out = [] | ||
out.append(search(mod)) | ||
for name in submodules[1:]: | ||
mod = getattr(mod, name) | ||
out.append(search(mod)) | ||
|
||
# drop empty module | ||
mods = [mod for mod in chain(*out) if mod] | ||
|
||
# store the modules search so it doesn't have to be performed again for this class | ||
__STORAGE_META__[subclass] = (mods, subclass, _MetaClass) | ||
|
||
# replace all subclass by its meta form | ||
for mod in mods: | ||
setattr(mod, subclass.__name__, _MetaClass) | ||
|
||
|
||
@contextmanager | ||
def init_meta_context() -> Generator: | ||
tchaton marked this conversation as resolved.
Show resolved
Hide resolved
|
||
rank_zero_warn( | ||
"Be aware this feature is highly experimental and there are a number of weird edge cases " | ||
"where it can internal assert and/or crash. A more stable version is to be expected from PyTorch 1.11." | ||
) | ||
_set_meta_device() | ||
yield | ||
_unset_meta_device() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.