# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import threading
from contextlib import contextmanager
from functools import partial
from itertools import chain
from types import ModuleType
from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type

import torch
from torch import nn, Tensor
from torch.nn import Module
from torch.nn.modules.container import ModuleDict, ModuleList, Sequential

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_META_AVAILABLE

if _TORCH_META_AVAILABLE:
    from torch._C import _DisableTorchDispatch  # type: ignore[attr-defined]

    ####################################################################
    # BELOW: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
    # TODO: Remove once merged and released on the PyTorch side.       #
    ####################################################################

    @contextmanager
    def enable_python_mode(cls) -> Iterator[None]:
        if not hasattr(cls, "__torch_dispatch__"):
            raise ValueError("The class passed to enable_python_mode must have a __torch_dispatch__ classmethod")
        if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
            raise ValueError("The argument passed to enable_python_mode must be the type of a Tensor subclass")
        torch._C._enter_python_mode(cls)
        try:
            yield
        finally:
            torch._C._exit_python_mode()

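    # Example (a sketch): any `torch.Tensor` subclass defining `__torch_dispatch__`
    # can be installed as the ambient "Python mode", so tensor factory calls are
    # routed through its dispatch method:
    #
    #     with enable_python_mode(_MetaContext):
    #         t = torch.empty(8, device="cpu")  # `_MetaContext` rewrites `device` to "meta"
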
    _tls = threading.local()
    _tls.in_call = False

    @contextmanager
    def _no_dispatch() -> Iterator[None]:
        """Temporarily disables the Python dispatch mode."""
        guard = _DisableTorchDispatch()  # noqa: F841
        try:
            yield
        finally:
            del guard

    def _handle_arange(func, args, kwargs):
        # `arange` must compute its values, so run it on the CPU and mimic
        # only the resulting shape/dtype on the meta device.
        kwargs["device"] = torch.device("cpu")
        return torch.empty_like(func(*args, **kwargs), device="meta")

    def _handle_tril(func, args, kwargs):
        # `tril` preserves the input shape, so an empty meta tensor suffices.
        if args and isinstance(args[0], Tensor):
            return torch.empty_like(args[0], device="meta")

        return NotImplemented

    class _MetaContext(Tensor):
        _op_handlers: Dict[Callable, Callable] = {}

        @classmethod
        def _ensure_handlers_initialized(cls) -> None:
            if cls._op_handlers:
                return

            cls._op_handlers.update(
                {
                    torch.ops.aten.arange: _handle_arange,
                    torch.ops.aten.tril: _handle_tril,
                }
            )

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            cls._ensure_handlers_initialized()

            # `kwargs` defaults to `None`, so guard before the membership check below.
            kwargs = kwargs or {}

            op_handler: Optional[Callable]

            try:
                op_handler = cls._op_handlers[func]
            except KeyError:
                op_handler = None

            with _no_dispatch():
                if op_handler:
                    result = op_handler(func, args, kwargs)
                    if result is not NotImplemented:
                        return result

                if "device" in kwargs:
                    kwargs["device"] = torch.device("meta")

                return func(*args, **kwargs)

    def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module:
        """Instantiate ``module_fn`` under the meta-device dispatch mode and attach a ``materialize`` hook that
        re-runs ``__init__`` later to allocate real tensors."""

        def create_instance(module=None) -> Module:
            if module:
                module.__init__(*args, **kwargs)
                return module
            return module_fn(*args, **kwargs)

        if _tls.in_call:
            module = create_instance()
        else:
            _tls.in_call = True
            try:
                with enable_python_mode(_MetaContext):
                    module = create_instance()
            finally:
                _tls.in_call = False

        module.materialize = partial(create_instance, module=module)  # type: ignore[assignment]

        return module

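    # Example (a minimal sketch, assuming PyTorch >= 1.10 meta-tensor support):
    #
    #     module = init_meta(nn.Linear, in_features=16, out_features=4)
    #     assert module.weight.device.type == "meta"  # no real storage allocated yet
    #     module = module.materialize()               # re-runs __init__ with real tensors
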
    def is_meta_init() -> bool:
        """Indicates whether the module is being instantiated by ``init_meta()``."""
        return _tls.in_call

    ####################################################################
    # ABOVE: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
    # TODO: Remove once merged and released on the PyTorch side.       #
    ####################################################################

else:

    def init_meta(*_, **__):
        raise MisconfigurationException("`init_meta` is only supported from PyTorch 1.10.0.")


# https://stackoverflow.com/a/63851681/9201239
def get_all_subclasses(cls: Type[nn.Module]) -> Set[Type[nn.Module]]:
    subclass_list = []

    def recurse(cl):
        for subclass in cl.__subclasses__():
            subclass_list.append(subclass)
            recurse(subclass)

    recurse(cls)

    return set(subclass_list)

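# Example (a sketch; the exact contents depend on the classes defined at runtime):
#
#     get_all_subclasses(nn.RNNBase)  # -> {nn.RNN, nn.LSTM, nn.GRU}
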
| 161 | + |
def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module: nn.Module) -> None:
    """Assign ``materialized_module`` at the dotted ``prefix`` path inside ``root_module``."""
    *path, name = prefix.split(".")
    for p in path:
        root_module = getattr(root_module, p)

    try:
        index = int(name)
        root_module[index] = materialized_module
    except ValueError:
        setattr(root_module, name, materialized_module)

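# Example (a sketch): replace the second block of a hypothetical `model.encoder`
# `Sequential`, registered under the dotted name "encoder.1":
#
#     recursively_setattr(model, "encoder.1", nn.Linear(8, 8))
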
| 173 | + |
def materialize_module(root_module: nn.Module) -> nn.Module:
    """This utility performs an in-place operation by materializing a module and its children."""
    if not _TORCH_META_AVAILABLE:
        return root_module

    materialize_fn = getattr(root_module, "materialize", None)
    if materialize_fn and not isinstance(root_module, (Sequential, ModuleList, ModuleDict)):
        return materialize_fn()

    for name, child in root_module.named_children():
        materialize_fn = getattr(child, "materialize", None)
        if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)):
            materialize_module(child)
        else:
            setattr(root_module, name, materialize_fn())
    return root_module

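# Example (a sketch; `MyEncoder` is a hypothetical `nn.Module` subclass):
#
#     module = init_meta(MyEncoder, hidden_dim=128)  # parameters live on "meta"
#     module = materialize_module(module)            # re-runs __init__ with real tensors
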
| 191 | + |
# Cache the subclasses to optimize the search when resetting the meta device later on.
__STORAGE_META__ = {}

__CREATED_MODULES__ = set()


def _unset_meta_device(from_created: bool = False) -> None:
    """Replace all meta modules with their original versions."""
    if not _TORCH_META_AVAILABLE:
        raise MisconfigurationException("`init_meta` is only supported from PyTorch 1.10.0.")

    if from_created:
        values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
    else:
        values = __STORAGE_META__.values()

    for mods, subclass, _ in values:
        for mod in mods:
            setattr(mod, subclass.__name__, subclass)


def _set_meta_device_populated(from_created: bool = False) -> None:
    """Replace all modules with their meta versions from the cache."""
    if not _TORCH_META_AVAILABLE:
        raise MisconfigurationException("`init_meta` is only supported from PyTorch 1.10.0.")

    if from_created:
        values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
    else:
        values = __STORAGE_META__.values()

    for mods, subclass, meta_class in values:
        for mod in mods:
            setattr(mod, subclass.__name__, meta_class)


def _set_meta_device() -> None:
    """Replace all ``torch.nn.Module`` subclasses by their meta replacements."""

    if not _TORCH_META_AVAILABLE:
        raise MisconfigurationException("`init_meta` is only supported from PyTorch 1.10.0.")

    # Author note: this can be optimized further by searching all subclasses at once.
    # The time complexity is O(n * m), where n is the number of ``nn.Module`` subclasses
    # (assuming no multiple inheritance) and m the number of modules in which a given
    # subclass is exposed.

    for subclass in get_all_subclasses(torch.nn.modules.module.Module):

        if subclass in (Sequential, ModuleList, ModuleDict):
            continue

        # if a subclass has already been stored, we should use the cache
        if subclass in __STORAGE_META__:
            # reset the class import package to its rightful state.
            mods, subclass, meta_class = __STORAGE_META__[subclass]
            for mod in mods:
                setattr(mod, subclass.__name__, meta_class)
            continue

        # Create a class subclassing the current `subclass` and overriding its `__new__` method.
        # This will enable us to use `init_meta` to create a `meta`
        # version of the current subclass module.
        class _MetaClass(subclass):
            @classmethod
            @contextmanager
            def instantiation_context(cls, materialize: bool):
                _unset_meta_device(from_created=True)
                yield
                _set_meta_device_populated(from_created=True)

            @classmethod
            def materialize(cls, materialize_fn: Callable):
                with cls.instantiation_context(materialize=True):
                    obj = materialize_fn()
                return obj

            @staticmethod
            def add_subclasses(subclass):
                """This is used to unroll the instantiation tree while creating the modules."""
                __CREATED_MODULES__.add(subclass)
                if subclass.__bases__[0] != torch.nn.modules.module.Module:
                    _MetaClass.add_subclasses(subclass.__bases__[0])

            def __new__(cls, *args, **kwargs):
                subclass = cls.__bases__[0]
                cls.add_subclasses(subclass)
                with cls.instantiation_context(materialize=False):
                    obj = init_meta(subclass, *args, **kwargs)

                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
                return obj

        def search(mod: ModuleType) -> List[ModuleType]:
            out = []
            for _, obj in inspect.getmembers(mod):
                if obj == subclass:
                    out.append(mod)
            return out

        submodules = subclass.__module__.split(".")
        mod = importlib.import_module(submodules[0])

        # An nn.Module class can be imported at different levels, and all of its aliases need to be mocked.
        # Example: `torch.nn.Linear` is actually `torch.nn.modules.linear.Linear`,
        # so `torch.nn.Linear`, `torch.nn.modules.Linear` and `torch.nn.modules.linear.Linear`
        # all need to be replaced by the `_MetaClass` version.
        out = []
        out.append(search(mod))
        for name in submodules[1:]:
            mod = getattr(mod, name)
            out.append(search(mod))

        # drop empty modules
        mods = [mod for mod in chain(*out) if mod]

        # store the module search so it doesn't have to be performed again for this class
        __STORAGE_META__[subclass] = (mods, subclass, _MetaClass)

        # replace all occurrences of the subclass by its meta form
        for mod in mods:
            setattr(mod, subclass.__name__, _MetaClass)


@contextmanager
def init_meta_context() -> Generator:
    rank_zero_warn(
        "Be aware this feature is highly experimental and there are a number of weird edge cases "
        "where it can internally assert and/or crash. A more stable version is expected from PyTorch 1.11."
    )
    _set_meta_device()
    yield
    _unset_meta_device()
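

# Example usage (a sketch; `MyLightningModule` is a hypothetical `LightningModule`):
#
#     with init_meta_context():
#         model = MyLightningModule()       # submodules are instantiated on the meta device
#     model = materialize_module(model)     # allocate real parameters when needed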