 from functools import partial
 from itertools import chain
 from types import ModuleType
-from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type
+from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type

 import torch
 from torch import nn, Tensor
 from torch.nn import Module
 from torch.nn.modules.container import ModuleDict, ModuleList, Sequential

+import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
@@ -191,7 +192,6 @@ def materialize_module(root_module: nn.Module) -> nn.Module:

 # cache subclasses to optimize the search when resetting the meta device later on.
 __STORAGE_META__ = {}
-
 __CREATED_MODULES__ = set()

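The `__STORAGE_META__` cache above remembers which modules a class was patched out of, so a plain `setattr` can put things back later. A minimal, standalone sketch of that patch-and-restore pattern (all names hypothetical, not from this file), mirroring the `setattr(mod, subclass.__name__, ...)` loop in the next hunk:

    import types

    # Stand-in for a module such as torch.nn.modules.linear.
    fake_mod = types.ModuleType("fake_mod")

    class Original:
        """Stand-in for the class being mocked."""

    fake_mod.Original = Original

    # Remember the owning module(s) and the class before patching.
    cache = {"Original": ([fake_mod], Original)}

    # Install a replacement subclass under the same attribute name.
    fake_mod.Original = type("_Patched", (Original,), {})
    assert fake_mod.Original is not Original

    # Later: restore the original class from the cache.
    mods, original = cache["Original"]
    for mod in mods:
        setattr(mod, original.__name__, original)
    assert fake_mod.Original is Original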
@@ -237,45 +237,52 @@ def _set_meta_device() -> None:

     for subclass in get_all_subclasses(torch.nn.modules.module.Module):

-        if isinstance(subclass, (Sequential, ModuleList, ModuleDict)):
+        if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule):
             continue

         # if a subclass has already been stored, we should use the cache
         if str(subclass) in __STORAGE_META__:
-            # reset the class import package to its rightfull state.
+            # reset the class import package to its rightful state.
             mods, subclass, meta_class = __STORAGE_META__[subclass]
             for mod in mods:
                 setattr(mod, subclass.__name__, meta_class)
             continue

+        class _IsinstanceMetaclass(type(subclass)):
+            def __instancecheck__(self, instance: Any) -> bool:
+                """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
+                return isinstance(instance, self.__bases__[0])
+
         # Create a class subclassing the current `subclass`, overriding its `__new__` method.
         # This will enable us to use `torch.distributed.nn.utils.init_meta` to create a `meta`
         # version of the current subclass module.
-        class _MetaClass(subclass):
+        class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
             @classmethod
             @contextmanager
-            def instantiation_context(cls, materialize: bool):
+            def instantiation_context(cls):
                 _unset_meta_device(from_created=True)
                 yield
                 _set_meta_device_populated(from_created=True)

             @classmethod
             def materialize(cls, materialize_fn: Callable):
-                with cls.instantiation_context(materialize=True):
+                with cls.instantiation_context():
                     obj = materialize_fn()
                 return obj

             @staticmethod
             def add_subclasses(subclass):
-                """This is used to unrol the instantion tree while creating the modules."""
-                __CREATED_MODULES__.add(subclass)
+                """This is used to unroll the instantiation tree while creating the modules."""
+                # Don't store the LightningModule, as it is skipped from the meta process.
+                if subclass != pl.LightningModule:
+                    __CREATED_MODULES__.add(subclass)
                 if subclass.__bases__[0] != torch.nn.modules.module.Module:
-                    _MetaClass.add_subclasses(subclass.__bases__[0])
+                    _MaterializerModule.add_subclasses(subclass.__bases__[0])

             def __new__(cls, *args, **kwargs):
                 subclass = cls.__bases__[0]
                 cls.add_subclasses(subclass)
-                with cls.instantiation_context(materialize=False):
+                with cls.instantiation_context():
                     obj = init_meta(subclass, *args, **kwargs)

                 obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
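The `_IsinstanceMetaclass` trick above keeps `isinstance` transparent: objects created from the generated `_MaterializerModule` subclass still pass checks against the class they wrap. A condensed, standalone illustration of the `__instancecheck__` delegation (class names hypothetical, not from this file):

    from typing import Any

    class Base:
        pass

    class _DelegatingMeta(type):
        def __instancecheck__(self, instance: Any) -> bool:
            # Redirect isinstance(x, Wrapper) to the wrapped base class.
            return isinstance(instance, self.__bases__[0])

    class Wrapper(Base, metaclass=_DelegatingMeta):
        pass

    # A plain Base instance satisfies a check against Wrapper because the
    # metaclass answers with isinstance(instance, Base).
    assert isinstance(Base(), Wrapper)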
@@ -294,9 +301,8 @@ def search(mod: ModuleType) -> List[ModuleType]:
         # An nn.Module class can be imported at different levels, and they all need to be mocked.
         # Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear
         # Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear
-        # needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass
-        out = []
-        out.append(search(mod))
+        # need to be replaced by the torch.nn.modules.linear.Linear _MaterializerModule
+        out = [search(mod)]
         for name in submodules[1:]:
             mod = getattr(mod, name)
             out.append(search(mod))
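The aliasing described in the comments above is easy to verify: the same class object is reachable through several import paths, so each owning module must be patched.

    import torch

    # One class object, multiple import paths; patching only one module
    # would leave the other names bound to the unpatched class.
    assert torch.nn.Linear is torch.nn.modules.linear.Linear
    assert torch.nn.modules.Linear is torch.nn.modules.linear.Linear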
@@ -305,11 +311,11 @@ def search(mod: ModuleType) -> List[ModuleType]:
         mods = [mod for mod in chain(*out) if mod]

         # store the module search so it doesn't have to be performed again for this class
-        __STORAGE_META__[subclass] = (mods, subclass, _MetaClass)
+        __STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule)

         # replace all subclasses by their meta form
         for mod in mods:
-            setattr(mod, subclass.__name__, _MetaClass)
+            setattr(mod, subclass.__name__, _MaterializerModule)


 @contextmanager
@@ -321,3 +327,11 @@ def init_meta_context() -> Generator:
     _set_meta_device()
     yield
     _unset_meta_device()
+
+
+def is_on_meta_device(module: nn.Module) -> bool:
+    try:
+        param = next(module.parameters())
+        return param.device.type == "meta"
+    except StopIteration:
+        return False
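Taken together, a usage sketch of the resulting API. The import path is assumed (this diff does not show which file it touches), and meta-tensor support requires a sufficiently recent PyTorch, per the `_TORCH_GREATER_EQUAL_1_10` import above:

    import torch

    # Assumed module path; the helpers appear in the diff, the file name does not.
    from pytorch_lightning.utilities.meta import (
        init_meta_context,
        is_on_meta_device,
        materialize_module,
    )

    with init_meta_context():
        # Instantiated inside the context, parameters land on the "meta"
        # device: shapes are tracked but no weight memory is allocated.
        model = torch.nn.Linear(in_features=16, out_features=4)

    assert is_on_meta_device(model)

    # materialize_module (visible in a hunk header above) replaces the
    # meta parameters with real, allocated ones.
    materialize_module(model)
    assert not is_on_meta_device(model)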