|
| 1 | +# Copyright The PyTorch Lightning team. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +import importlib |
| 15 | +import os |
| 16 | +from collections import UserDict |
| 17 | +from inspect import getmembers, isclass |
| 18 | +from pathlib import Path |
| 19 | +from typing import Any, Callable, List, Optional |
| 20 | + |
| 21 | +from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin |
| 22 | +from pytorch_lightning.utilities.exceptions import MisconfigurationException |
| 23 | + |
| 24 | + |
| 25 | +class _TrainingTypePluginsRegistry(UserDict): |
| 26 | + """ |
| 27 | + This class is a Registry that stores information about the Training Type Plugins. |
| 28 | +
|
| 29 | + The Plugins are mapped to strings. These strings are names that idenitify |
| 30 | + a plugin, e.g., "deepspeed". It also returns Optional description and |
| 31 | + parameters to initialize the Plugin, which were defined durng the |
| 32 | + registration. |
| 33 | +
|
| 34 | + The motivation for having a TrainingTypePluginRegistry is to make it convenient |
| 35 | + for the Users to try different Plugins by passing just strings |
| 36 | + to the plugins flag to the Trainer. |
| 37 | +
|
| 38 | + Example:: |
| 39 | +
|
| 40 | + @TrainingTypePluginsRegistry.register("lightning", description="Super fast", a=1, b=True) |
| 41 | + class LightningPlugin: |
| 42 | + def __init__(self, a, b): |
| 43 | + ... |
| 44 | +
|
| 45 | + or |
| 46 | +
|
| 47 | + TrainingTypePluginsRegistry.register("lightning", LightningPlugin, description="Super fast", a=1, b=True) |
| 48 | +
|
| 49 | + """ |
| 50 | + |
| 51 | + def register( |
| 52 | + self, |
| 53 | + name: str, |
| 54 | + plugin: Optional[Callable] = None, |
| 55 | + description: Optional[str] = None, |
| 56 | + override: bool = False, |
| 57 | + **init_params: Any, |
| 58 | + ) -> Callable: |
| 59 | + """ |
| 60 | + Registers a plugin mapped to a name and with required metadata. |
| 61 | +
|
| 62 | + Args: |
| 63 | + name : the name that identifies a plugin, e.g. "deepspeed_stage_3" |
| 64 | + plugin : plugin class |
| 65 | + description : plugin description |
| 66 | + override : overrides the registered plugin, if True |
| 67 | + init_params: parameters to initialize the plugin |
| 68 | + """ |
| 69 | + if not (name is None or isinstance(name, str)): |
| 70 | + raise TypeError(f'`name` must be a str, found {name}') |
| 71 | + |
| 72 | + if name in self and not override: |
| 73 | + raise MisconfigurationException( |
| 74 | + f"'{name}' is already present in the registry." |
| 75 | + " HINT: Use `override=True`." |
| 76 | + ) |
| 77 | + |
| 78 | + data = {} |
| 79 | + data["description"] = description if description is not None else "" |
| 80 | + |
| 81 | + data["init_params"] = init_params |
| 82 | + |
| 83 | + def do_register(plugin: Callable) -> Callable: |
| 84 | + data["plugin"] = plugin |
| 85 | + self[name] = data |
| 86 | + return plugin |
| 87 | + |
| 88 | + if plugin is not None: |
| 89 | + return do_register(plugin) |
| 90 | + |
| 91 | + return do_register |
| 92 | + |
| 93 | + def get(self, name: str) -> Any: |
| 94 | + """ |
| 95 | + Calls the registered plugin with the required parameters |
| 96 | + and returns the plugin object |
| 97 | +
|
| 98 | + Args: |
| 99 | + name (str): the name that identifies a plugin, e.g. "deepspeed_stage_3" |
| 100 | + """ |
| 101 | + if name in self: |
| 102 | + data = self[name] |
| 103 | + return data["plugin"](**data["init_params"]) |
| 104 | + |
| 105 | + err_msg = "'{}' not found in registry. Available names: {}" |
| 106 | + available_names = ", ".join(sorted(self.keys())) or "none" |
| 107 | + raise KeyError(err_msg.format(name, available_names)) |
| 108 | + |
| 109 | + def remove(self, name: str) -> None: |
| 110 | + """Removes the registered plugin by name""" |
| 111 | + self.pop(name) |
| 112 | + |
| 113 | + def available_plugins(self) -> List: |
| 114 | + """Returns a list of registered plugins""" |
| 115 | + return list(self.keys()) |
| 116 | + |
| 117 | + def __str__(self) -> str: |
| 118 | + return "Registered Plugins: {}".format(", ".join(self.keys())) |
| 119 | + |
| 120 | + |
| 121 | +TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry() |
| 122 | + |
| 123 | + |
| 124 | +def is_register_plugins_overridden(plugin: Callable) -> bool: |
| 125 | + method_name = "register_plugins" |
| 126 | + plugin_attr = getattr(plugin, method_name) |
| 127 | + super_attr = getattr(TrainingTypePlugin, method_name) |
| 128 | + |
| 129 | + if hasattr(plugin_attr, 'patch_loader_code'): |
| 130 | + is_overridden = plugin_attr.patch_loader_code != str(super_attr.__code__) |
| 131 | + else: |
| 132 | + is_overridden = plugin_attr.__code__ is not super_attr.__code__ |
| 133 | + return is_overridden |
| 134 | + |
| 135 | + |
| 136 | +def call_training_type_register_plugins(root: Path, base_module: str) -> None: |
| 137 | + # Ref: https://github.com/facebookresearch/ClassyVision/blob/master/classy_vision/generic/registry_utils.py#L14 |
| 138 | + directory = "training_type" |
| 139 | + for file in os.listdir(root / directory): |
| 140 | + if file.endswith(".py") and not file.startswith("_"): |
| 141 | + module = file[:file.find(".py")] |
| 142 | + module = importlib.import_module(".".join([base_module, module])) |
| 143 | + for _, mod in getmembers(module, isclass): |
| 144 | + if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overridden(mod): |
| 145 | + mod.register_plugins(TrainingTypePluginsRegistry) |
| 146 | + break |
0 commit comments