Commit 1de3539

Resolve instantiation problem with init_meta_context (#10493)
1 parent ae71284 commit 1de3539

File tree (5 files changed: +58, -21 lines)

CHANGELOG.md
pytorch_lightning/core/mixins/device_dtype_mixin.py
pytorch_lightning/trainer/trainer.py
pytorch_lightning/utilities/meta.py
tests/utilities/test_meta.py

CHANGELOG.md
Lines changed: 3 additions & 0 deletions

@@ -142,6 +142,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))


+- Fixed `isinstance` not working with `init_meta_context`, materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493))
+
+
 - Fixed an issue that prevented the Trainer to shutdown workers when execution is interrupted due to failure([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))
pytorch_lightning/core/mixins/device_dtype_mixin.py
Lines changed: 5 additions & 1 deletion

@@ -17,6 +17,8 @@
 import torch
 from torch.nn import Module

+import pytorch_lightning as pl
+

 class DeviceDtypeModuleMixin(Module):
     __jit_unused_properties__ = ["device", "dtype"]
@@ -177,7 +179,9 @@ def __update_properties(
         self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
     ) -> None:
         def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None:
-            if not isinstance(module, DeviceDtypeModuleMixin):
+            # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't
+            # work when using `init_meta_context`.
+            if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)):
                 return
             if device is not None:
                 module._device = device
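Context for the change: `__update_properties` walks the module tree and caches the target device on every submodule the `isinstance` check recognizes, so any module the check misses keeps a stale `_device` — the "materialized model not being moved to the device" symptom in the changelog. A minimal sketch of that traversal pattern, assuming the walk is dispatched through `Module.apply` (the mixin class here is an illustrative stand-in, not the real one):

import torch
from torch import nn

class DeviceMixin(nn.Module):
    """Stand-in for DeviceDtypeModuleMixin: caches the device it was moved to."""
    _device = torch.device("cpu")

def update_device(root: nn.Module, device: torch.device) -> None:
    def apply_fn(module: nn.Module) -> None:
        if not isinstance(module, DeviceMixin):
            return  # modules missed by the check keep a stale `_device`
        module._device = device

    # `Module.apply` visits `root` and every submodule recursively.
    root.apply(apply_fn)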

pytorch_lightning/trainer/trainer.py
Lines changed: 13 additions & 2 deletions

@@ -84,7 +84,7 @@
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
-from pytorch_lightning.utilities.meta import materialize_module
+from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import (
@@ -1406,10 +1406,21 @@ def _call_setup_hook(self) -> None:

     def _call_configure_sharded_model(self) -> None:
         with self.accelerator.model_sharded_context():
-            materialize_module(self.lightning_module)
+            self._handle_meta_model()
             self.call_hook("configure_sharded_model")
             self.call_hook("on_configure_sharded_model")

+    def _handle_meta_model(self) -> None:
+        if not is_on_meta_device(self.lightning_module):
+            return
+
+        if isinstance(self.training_type_plugin, DDPSpawnPlugin):
+            raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.")
+
+        materialize_module(self.lightning_module)
+        # the trainer reference is lost during materialization
+        self.lightning_module.trainer = proxy(self)
+
     def _call_teardown_hook(self) -> None:
         fn = self.state.fn._setup_fn
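In practice the new hook supports deferring weight allocation until training starts: build the model under `init_meta_context`, then let the Trainer materialize it. A hedged end-to-end sketch; `MyLightningModule` and the trainer arguments are placeholders, and spawn-based plugins (`DDPSpawnPlugin`) raise `MisconfigurationException` on this path:

import pytorch_lightning as pl
from pytorch_lightning.utilities.meta import init_meta_context

with init_meta_context():
    # Parameters are allocated on the "meta" device: shapes only, no storage.
    model = MyLightningModule()  # hypothetical LightningModule subclass

trainer = pl.Trainer(max_epochs=1)
# During setup, `_handle_meta_model` detects the meta parameters,
# materializes them, and re-attaches the weakref `trainer` proxy.
trainer.fit(model)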

pytorch_lightning/utilities/meta.py
Lines changed: 30 additions & 16 deletions

@@ -18,13 +18,14 @@
 from functools import partial
 from itertools import chain
 from types import ModuleType
-from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type
+from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type

 import torch
 from torch import nn, Tensor
 from torch.nn import Module
 from torch.nn.modules.container import ModuleDict, ModuleList, Sequential

+import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
@@ -191,7 +192,6 @@ def materialize_module(root_module: nn.Module) -> nn.Module:

 # cache subclasses to optimize the search when resetting the meta device later on.
 __STORAGE_META__ = {}
-
 __CREATED_MODULES__ = set()


@@ -237,45 +237,52 @@ def _set_meta_device() -> None:

     for subclass in get_all_subclasses(torch.nn.modules.module.Module):

-        if isinstance(subclass, (Sequential, ModuleList, ModuleDict)):
+        if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule):
             continue

         # if a subclass has already been stored, we should use the cache
         if str(subclass) in __STORAGE_META__:
-            # reset the class import package to its rightfull state.
+            # reset the class import package to its rightful state.
             mods, subclass, meta_class = __STORAGE_META__[subclass]
             for mod in mods:
                 setattr(mod, subclass.__name__, meta_class)
             continue

+        class _IsinstanceMetaclass(type(subclass)):
+            def __instancecheck__(self, instance: Any) -> bool:
+                """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
+                return isinstance(instance, self.__bases__[0])
+
         # Create a class subclassing current `subclass` overriding its new method.
         # this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta`
         # version of the current subclass module
-        class _MetaClass(subclass):
+        class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
             @classmethod
             @contextmanager
-            def instantiation_context(cls, materialize: bool):
+            def instantiation_context(cls):
                 _unset_meta_device(from_created=True)
                 yield
                 _set_meta_device_populated(from_created=True)

             @classmethod
             def materialize(cls, materialize_fn: Callable):
-                with cls.instantiation_context(materialize=True):
+                with cls.instantiation_context():
                     obj = materialize_fn()
                 return obj

             @staticmethod
             def add_subclasses(subclass):
-                """This is used to unrol the instantion tree while creating the modules."""
-                __CREATED_MODULES__.add(subclass)
+                """This is used to unroll the instantiation tree while creating the modules."""
+                # Don't store the LightningModule as skipped from the Meta process.
+                if subclass != pl.LightningModule:
+                    __CREATED_MODULES__.add(subclass)
                 if subclass.__bases__[0] != torch.nn.modules.module.Module:
-                    _MetaClass.add_subclasses(subclass.__bases__[0])
+                    _MaterializerModule.add_subclasses(subclass.__bases__[0])

             def __new__(cls, *args, **kwargs):
                 subclass = cls.__bases__[0]
                 cls.add_subclasses(subclass)
-                with cls.instantiation_context(materialize=False):
+                with cls.instantiation_context():
                     obj = init_meta(subclass, *args, **kwargs)

                 obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
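The `_IsinstanceMetaclass` addition is the heart of the `isinstance` fix: once `setattr(mod, subclass.__name__, _MaterializerModule)` swaps the public name, a user's `isinstance(obj, nn.Linear)` really checks against the replacement class, so the metaclass forwards the check to the wrapped base class. A self-contained sketch of the same technique, with illustrative names:

class Base:
    """Stands in for a real module class such as nn.Linear."""

class _IsinstanceMetaclass(type):
    def __instancecheck__(self, instance) -> bool:
        # Delegate the isinstance check to the wrapped (first) base class.
        return isinstance(instance, self.__bases__[0])

class Replacement(Base, metaclass=_IsinstanceMetaclass):
    """Stands in for the generated _MaterializerModule."""

obj = Base()  # e.g. a module materialized from its original class
assert isinstance(obj, Replacement)  # True: the check falls through to Base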
@@ -294,9 +301,8 @@ def search(mod: ModuleType) -> List[ModuleType]:
         # nn.Module class can be imported at different level and they all need to be mocked.
         # Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear
         # Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear
-        # needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass
-        out = []
-        out.append(search(mod))
+        # needs to be replaced by the torch.nn.linear.modules.Linear _MaterializerModule
+        out = [search(mod)]
         for name in submodules[1:]:
             mod = getattr(mod, name)
             out.append(search(mod))
@@ -305,11 +311,11 @@ def search(mod: ModuleType) -> List[ModuleType]:
         mods = [mod for mod in chain(*out) if mod]

         # store the modules search so it doesn't have to be performed again for this class
-        __STORAGE_META__[subclass] = (mods, subclass, _MetaClass)
+        __STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule)

         # replace all subclass by its meta form
         for mod in mods:
-            setattr(mod, subclass.__name__, _MetaClass)
+            setattr(mod, subclass.__name__, _MaterializerModule)


 @contextmanager
@@ -321,3 +327,11 @@ def init_meta_context() -> Generator:
     _set_meta_device()
     yield
     _unset_meta_device()
+
+
+def is_on_meta_device(module: nn.Module) -> bool:
+    try:
+        param = next(module.parameters())
+        return param.device.type == "meta"
+    except StopIteration:
+        return False
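The new `is_on_meta_device` helper inspects only the module's first parameter, and parameterless modules fall into the `StopIteration` branch and report `False`. A small usage sketch mirroring the assertions in the test below (requires torch >= 1.10):

from torch import nn
from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device

with init_meta_context():
    linear = nn.Linear(1, 1)
    assert is_on_meta_device(linear)           # parameters live on "meta"
    assert not is_on_meta_device(nn.Module())  # no parameters -> False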

tests/utilities/test_meta.py
Lines changed: 7 additions & 2 deletions

@@ -14,7 +14,7 @@
 from torch import nn

 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities.meta import init_meta_context, materialize_module
+from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module
 from tests.helpers.runif import RunIf


@@ -31,18 +31,23 @@ def __init__(self, num_layers: int):
         self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)])


-@RunIf(min_torch="1.10.0")
+@RunIf(special=True, min_torch="1.10.0")
 def test_init_meta_context():

     with init_meta_context():
         m = nn.Linear(in_features=1, out_features=1)
+        assert isinstance(m, nn.Linear)
         assert m.weight.device.type == "meta"
+        assert is_on_meta_device(m)
         mlp = MLP(4)
         assert mlp.layer[0].weight.device.type == "meta"

         mlp = materialize_module(mlp)
         assert mlp.layer[0].weight.device.type == "cpu"

+        assert not is_on_meta_device(mlp)
+        assert not is_on_meta_device(nn.Module())
+
         model = BoringModel(4)
         assert model.layer[0].weight.device.type == "meta"
         materialize_module(model)
