
Commit 454e93b

Add support for init_meta_context, materialize_module (#9920)

1 parent 4ea72a9 commit 454e93b

File tree

7 files changed: +412 −2 lines

CHANGELOG.md

Lines changed: 4 additions & 0 deletions

@@ -195,6 +195,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))

+- Added `init_meta_context`, `materialize_module` utilities ([#9920](https://github.com/PyTorchLightning/pytorch-lightning/pull/9920))
+
+
 - Added `TPUPrecisionPlugin` ([#10020](https://github.com/PyTorchLightning/pytorch-lightning/pull/10020))


@@ -221,6 +224,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))


+
 ### Changed

 - Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)).
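As a quick illustration of the two utilities added in this entry (a sketch of the intended usage only; the commit itself flags the feature as highly experimental and it requires a recent PyTorch 1.10 build):

from torch import nn

from pytorch_lightning.utilities.meta import init_meta_context, materialize_module

# Modules built inside the context carry "meta" parameters: shapes and dtypes
# are tracked, but no memory is allocated and no values exist yet.
with init_meta_context():
    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
print(model[0].weight.device)  # meta

# Rebuild the children in place with real (CPU) storage once memory is needed.
materialize_module(model)
print(model[0].weight.device)  # cpu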

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 1 addition & 1 deletion

@@ -426,7 +426,7 @@ def _setup_model_and_optimizer(
     def init_deepspeed(self):
         # check that `configure_gradient_clipping` hook isn't overridden since deepspeed handles
         # gradient clipping internally
-        if is_overridden("configure_gradient_clipping", self.lightning_module):
+        if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
             rank_zero_warn(
                 "Since deepspeed handles gradient clipping internally, this hook will"
                 " be ignored. Consider setting `gradient_clip_val` and `gradient_clip_algorithm`"

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 0 deletions

@@ -89,6 +89,7 @@
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
+from pytorch_lightning.utilities.meta import materialize_module
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import (

@@ -1349,6 +1350,7 @@ def _call_setup_hook(self) -> None:

     def _call_configure_sharded_model(self) -> None:
         with self.accelerator.model_sharded_context():
+            materialize_module(self.lightning_module)
             self.call_hook("configure_sharded_model")
             self.call_hook("on_configure_sharded_model")
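Why materialize before calling the `configure_sharded_model` hooks: parameters created on the meta device have shapes but no storage, so they cannot participate in a forward pass until they are rebuilt. A small, self-contained illustration (plain PyTorch, not Lightning code; the `device=` factory argument needs a recent PyTorch release):

import torch
from torch import nn

# A layer whose parameters live on the "meta" device: shape/dtype only, no memory.
meta_layer = nn.Linear(32, 2, device="meta")
print(meta_layer.weight.shape)   # torch.Size([2, 32])
print(meta_layer.weight.device)  # meta

# Only a layer backed by real storage can run a forward pass.
real_layer = nn.Linear(32, 2)
out = real_layer(torch.randn(4, 32))
print(out.shape)                 # torch.Size([4, 2])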

pytorch_lightning/utilities/imports.py

Lines changed: 1 addition & 0 deletions

@@ -93,6 +93,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
 _OMEGACONF_AVAILABLE = _module_available("omegaconf")
 _POPTORCH_AVAILABLE = _module_available("poptorch")
 _RICH_AVAILABLE = _module_available("rich") and _compare_version("rich", operator.ge, "10.2.2")
+_TORCH_META_AVAILABLE = _compare_version("torch", operator.ge, "1.10.0.dev20210922")
 _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
 _TORCHTEXT_AVAILABLE = _module_available("torchtext")
 _TORCHVISION_AVAILABLE = _module_available("torchvision")
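`_TORCH_META_AVAILABLE` gates the new utilities on a nightly build that ships the dispatch hooks. The check boils down to a PEP 440 version comparison; a standalone sketch of the same idea (the helper name `torch_at_least` is made up for illustration):

import torch
from packaging.version import Version

def torch_at_least(minimum: str) -> bool:
    # Drop local version identifiers such as "+cu113" before comparing.
    return Version(torch.__version__.split("+")[0]) >= Version(minimum)

TORCH_META_OK = torch_at_least("1.10.0.dev20210922")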

pytorch_lightning/utilities/meta.py

Lines changed: 323 additions & 0 deletions

@@ -0,0 +1,323 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import threading
from contextlib import contextmanager
from functools import partial
from itertools import chain
from types import ModuleType
from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type

import torch
from torch import nn, Tensor
from torch.nn import Module
from torch.nn.modules.container import ModuleDict, ModuleList, Sequential

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_META_AVAILABLE

if _TORCH_META_AVAILABLE:
    from torch._C import _DisableTorchDispatch  # type: ignore[attr-defined]

    ####################################################################
    # BELOW: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
    # TODO: Remove once merged and released on the PyTorch side.       #
    ####################################################################

    @contextmanager
    def enable_python_mode(cls) -> Iterator[None]:
        if not hasattr(cls, "__torch_dispatch__"):
            raise ValueError("The class passed to enable_python_mode must have a __torch_dispatch__ classmethod")
        if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
            raise ValueError("The argument passed to enable_python_mode must be the type of a Tensor subclass")
        torch._C._enter_python_mode(cls)
        try:
            yield
        finally:
            torch._C._exit_python_mode()

    _tls = threading.local()
    _tls.in_call = False

    @contextmanager
    def _no_dispatch() -> Iterator[None]:
        """Temporarily disables the Python dispatch mode."""
        guard = _DisableTorchDispatch()  # noqa: F841
        try:
            yield
        finally:
            del guard

    def _handle_arange(func, args, kwargs):
        kwargs["device"] = torch.device("cpu")
        return torch.empty_like(func(*args, **kwargs), device="meta")

    def _handle_tril(func, args, kwargs):
        if args and isinstance(args[0], Tensor):
            return torch.empty_like(args[0], device="meta")

        return NotImplemented

    class _MetaContext(Tensor):
        _op_handlers: Dict[Callable, Callable] = {}

        @classmethod
        def _ensure_handlers_initialized(cls) -> None:
            if cls._op_handlers:
                return

            cls._op_handlers.update(
                {
                    torch.ops.aten.arange: _handle_arange,
                    torch.ops.aten.tril: _handle_tril,
                }
            )

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            cls._ensure_handlers_initialized()

            op_handler: Optional[Callable]

            try:
                op_handler = cls._op_handlers[func]
            except KeyError:
                op_handler = None

            with _no_dispatch():
                if op_handler:
                    result = op_handler(func, args, kwargs)
                    if result is not NotImplemented:
                        return result

                if "device" in kwargs:
                    kwargs["device"] = torch.device("meta")

                return func(*args, **kwargs)

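Aside (illustration, not part of the committed file): `_MetaContext` works by hooking `__torch_dispatch__`, which lets a `Tensor` subclass see every ATen call before it runs. A minimal sketch of that interception mechanism, hedged because the API was still a prototype in PyTorch 1.10 and details vary between versions:

import torch

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"intercepted: {func}")

        # Unwrap back to plain tensors so the real kernel can run.
        def unwrap(t):
            return t.as_subclass(torch.Tensor) if isinstance(t, LoggingTensor) else t

        return func(*[unwrap(a) for a in args], **{k: unwrap(v) for k, v in kwargs.items()})

x = torch.randn(3).as_subclass(LoggingTensor)
y = x + 1  # prints the intercepted aten op, then computes as usual
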
    def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module:
        def create_instance(module=None) -> Module:
            if module:
                module.__init__(*args, **kwargs)
                return module
            return module_fn(*args, **kwargs)

        if _tls.in_call:
            module = create_instance()
        else:
            _tls.in_call = True
            try:
                with enable_python_mode(_MetaContext):
                    module = create_instance()
            finally:
                _tls.in_call = False

        module.materialize = partial(create_instance, module=module)  # type: ignore[assignment]

        return module

    def is_meta_init() -> bool:
        """Indicates whether the module is being instantiated by ``init_meta()``."""
        return _tls.in_call

    ####################################################################
    # ABOVE: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
    # TODO: Remove once merged and released on the PyTorch side.       #
    ####################################################################

else:

    def init_meta(*_, **__):
        if not _TORCH_META_AVAILABLE:
            raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")


# https://stackoverflow.com/a/63851681/9201239
def get_all_subclasses(cls: Type[nn.Module]) -> Set[Type[nn.Module]]:
    subclass_list = []

    def recurse(cl):
        for subclass in cl.__subclasses__():
            subclass_list.append(subclass)
            recurse(subclass)

    recurse(cls)

    return set(subclass_list)


def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module: nn.Module) -> None:
    *path, name = prefix.split(".")
    for p in path:
        root_module = getattr(root_module, p)

    try:
        index = int(name)
        root_module[index] = materialized_module
    except ValueError:
        setattr(root_module, name, materialized_module)

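Aside (illustration, not part of the committed file): `recursively_setattr` walks a dotted path from a root module and replaces the final attribute, falling back to indexed assignment when the last path component parses as an integer, e.g. a position inside a `Sequential`:

from torch import nn

from pytorch_lightning.utilities.meta import recursively_setattr

parent = nn.Module()
parent.block = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

# "block.0" -> walk to `parent.block`, then assign by index because "0" is an int.
recursively_setattr(parent, "block.0", nn.Linear(4, 8))
print(parent.block[0])  # Linear(in_features=4, out_features=8, bias=True)
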
def materialize_module(root_module: nn.Module) -> nn.Module:
    """Materialize a module and its children in place."""
    if not _TORCH_META_AVAILABLE:
        return root_module

    materialize_fn = getattr(root_module, "materialize", None)
    if materialize_fn and not isinstance(root_module, (Sequential, ModuleList, ModuleDict)):
        return materialize_fn()

    for name, child in root_module.named_children():
        materialize_fn = getattr(child, "materialize", None)
        if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)):
            materialize_module(child)
        else:
            setattr(root_module, name, materialize_fn())
    return root_module

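Aside (illustration, not part of the committed file): a sketch of the pair `init_meta` / `materialize_module` used directly, assuming a PyTorch build that satisfies `_TORCH_META_AVAILABLE`:

from torch import nn

from pytorch_lightning.utilities.meta import init_meta, materialize_module

# Build the module with meta parameters first; no memory is allocated here.
meta_model = init_meta(nn.Linear, 10_000, 10_000)
print(meta_model.weight.device)  # meta

# Re-run __init__ with real storage only when the parameters are actually needed.
model = materialize_module(meta_model)
print(model.weight.device)       # cpu
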
# cache subclasses to optimize the search when resetting the meta device later on.
__STORAGE_META__ = {}

__CREATED_MODULES__ = set()


def _unset_meta_device(from_created: bool = False) -> None:
    """Replace all meta modules with their original classes."""
    if not _TORCH_META_AVAILABLE:
        raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")

    if from_created:
        values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
    else:
        values = __STORAGE_META__.values()

    for mods, subclass, _ in values:
        for mod in mods:
            setattr(mod, subclass.__name__, subclass)


def _set_meta_device_populated(from_created: bool = False) -> None:
    """Replace all modules with their already-created meta classes."""
    if not _TORCH_META_AVAILABLE:
        raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")

    if from_created:
        values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
    else:
        values = __STORAGE_META__.values()

    for mods, subclass, meta_class in values:
        for mod in mods:
            setattr(mod, subclass.__name__, meta_class)


def _set_meta_device() -> None:
    """Replace all torch.nn.Module subclasses with their meta replacements."""

    if not _TORCH_META_AVAILABLE:
        raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")

    # Author note: This can be optimized further by searching all subclasses at once.
    # The time complexity is O(n * m), where n is the number of nn.Module subclasses (assuming no
    # multiple inheritance) and m is the number of modules through which each subclass is importable.

    for subclass in get_all_subclasses(torch.nn.modules.module.Module):

        if issubclass(subclass, (Sequential, ModuleList, ModuleDict)):
            continue

        # if a subclass has already been stored, we should use the cache
        if subclass in __STORAGE_META__:
            # reset the class import package to its rightful state.
            mods, subclass, meta_class = __STORAGE_META__[subclass]
            for mod in mods:
                setattr(mod, subclass.__name__, meta_class)
            continue

        # Create a class subclassing the current `subclass` and overriding its `__new__` method.
        # This will enable us to use `init_meta` to create a `meta`
        # version of the current subclass module.
        class _MetaClass(subclass):
            @classmethod
            @contextmanager
            def instantiation_context(cls, materialize: bool):
                _unset_meta_device(from_created=True)
                yield
                _set_meta_device_populated(from_created=True)

            @classmethod
            def materialize(cls, materialize_fn: Callable):
                with cls.instantiation_context(materialize=True):
                    obj = materialize_fn()
                return obj

            @staticmethod
            def add_subclasses(subclass):
                """This is used to unroll the instantiation tree while creating the modules."""
                __CREATED_MODULES__.add(subclass)
                if subclass.__bases__[0] != torch.nn.modules.module.Module:
                    _MetaClass.add_subclasses(subclass.__bases__[0])

            def __new__(cls, *args, **kwargs):
                subclass = cls.__bases__[0]
                cls.add_subclasses(subclass)
                with cls.instantiation_context(materialize=False):
                    obj = init_meta(subclass, *args, **kwargs)

                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
                return obj

        def search(mod: ModuleType) -> List[ModuleType]:
            out = []
            for _, obj in inspect.getmembers(mod):
                if obj == subclass:
                    out.append(mod)
            return out

        submodules = subclass.__module__.split(".")
        mod = importlib.import_module(submodules[0])

        # An nn.Module class can be imported at several levels and all of them need to be mocked.
        # Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear.
        # Therefore, torch.nn.Linear, torch.nn.modules.Linear and torch.nn.modules.linear.Linear
        # all need to be replaced by the _MetaClass generated for Linear.
        out = []
        out.append(search(mod))
        for name in submodules[1:]:
            mod = getattr(mod, name)
            out.append(search(mod))

        # drop empty modules
        mods = [mod for mod in chain(*out) if mod]

        # store the module search so it doesn't have to be performed again for this class
        __STORAGE_META__[subclass] = (mods, subclass, _MetaClass)

        # replace the subclass with its meta form everywhere it is importable
        for mod in mods:
            setattr(mod, subclass.__name__, _MetaClass)


@contextmanager
def init_meta_context() -> Generator:
    rank_zero_warn(
        "Be aware this feature is highly experimental and there are a number of weird edge cases "
        "where it can internally assert and/or crash. A more stable version is expected in PyTorch 1.11."
    )
    _set_meta_device()
    yield
    _unset_meta_device()
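
Aside (illustration, not part of the committed file): the core trick behind `_set_meta_device` is plain module-level monkey-patching. Rebinding a class name inside the modules that export it changes what later lookups construct, and `_unset_meta_device` simply rebinds the originals. In miniature:

import types

# Stand-ins for torch.nn.Linear and the generated _MetaClass.
class Linear:
    ...

class MetaLinear(Linear):
    ...

fake_nn = types.ModuleType("fake_nn")
fake_nn.Linear = Linear

# What `setattr(mod, subclass.__name__, _MetaClass)` does inside the context:
fake_nn.Linear = MetaLinear
assert issubclass(fake_nn.Linear, Linear)

# What `_unset_meta_device` does on exit:
fake_nn.Linear = Linear
assert fake_nn.Linear is Linear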
