diff --git a/CHANGELOG.md b/CHANGELOG.md
index b90cf41b0736b..031b573461302 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -164,6 +164,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `Callback.state_dict()` and `Callback.load_state_dict()` methods ([#12232](https://github.com/PyTorchLightning/pytorch-lightning/pull/12232))
 
+- Added `AcceleratorRegistry` ([#12180](https://github.com/PyTorchLightning/pytorch-lightning/pull/12180))
+
+
 ### Changed
 
 - Drop PyTorch 1.7 support ([#12191](https://github.com/PyTorchLightning/pytorch-lightning/pull/12191))
 
diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py
index 1c9e0024f39bd..fe9ae0a120cfb 100644
--- a/pytorch_lightning/accelerators/__init__.py
+++ b/pytorch_lightning/accelerators/__init__.py
@@ -14,4 +14,9 @@
 from pytorch_lightning.accelerators.cpu import CPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.gpu import GPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.ipu import IPUAccelerator  # noqa: F401
+from pytorch_lightning.accelerators.registry import AcceleratorRegistry, call_register_accelerators  # noqa: F401
 from pytorch_lightning.accelerators.tpu import TPUAccelerator  # noqa: F401
+
+ACCELERATORS_BASE_MODULE = "pytorch_lightning.accelerators"
+
+call_register_accelerators(ACCELERATORS_BASE_MODULE)
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index ad0779d88b96c..526cec3e47319 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -75,7 +75,6 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         """Detect if the hardware is available."""
 
-    @staticmethod
-    @abstractmethod
-    def name() -> str:
-        """Name of the Accelerator."""
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        pass
diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py
index a027e7db6e209..3d28a4d80f682 100644
--- a/pytorch_lightning/accelerators/cpu.py
+++ b/pytorch_lightning/accelerators/cpu.py
@@ -63,7 +63,10 @@ def is_available() -> bool:
         """CPU is always available for execution."""
         return True
 
-    @staticmethod
-    def name() -> str:
-        """Name of the Accelerator."""
-        return "cpu"
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        accelerator_registry.register(
+            "cpu",
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index 529d067025f97..1f74da7da3f4e 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -93,10 +93,13 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return torch.cuda.device_count() > 0
 
-    @staticmethod
-    def name() -> str:
-        """Name of the Accelerator."""
-        return "gpu"
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        accelerator_registry.register(
+            "gpu",
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
 
 
 def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:
diff --git a/pytorch_lightning/accelerators/ipu.py b/pytorch_lightning/accelerators/ipu.py
index 1e8b2bc27fe57..b5110e58028a5 100644
--- a/pytorch_lightning/accelerators/ipu.py
+++ b/pytorch_lightning/accelerators/ipu.py
@@ -47,7 +47,10 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return _IPU_AVAILABLE
 
-    @staticmethod
-    def name() -> str:
-        """Name of the Accelerator."""
-        return "ipu"
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        accelerator_registry.register(
+            "ipu",
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
diff --git a/pytorch_lightning/accelerators/registry.py b/pytorch_lightning/accelerators/registry.py
new file mode 100644
index 0000000000000..992fa34b02aee
--- /dev/null
+++ b/pytorch_lightning/accelerators/registry.py
@@ -0,0 +1,122 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+from inspect import getmembers, isclass
+from typing import Any, Callable, Dict, List, Optional
+
+from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.registry import _is_register_method_overridden
+
+
+class _AcceleratorRegistry(dict):
+    """This class is a Registry that stores information about the Accelerators.
+
+    The Accelerators are mapped to strings. These strings are names that identify
+    an accelerator, e.g., "gpu". It also returns Optional description and
+    parameters to initialize the Accelerator, which were defined during the
+    registration.
+
+    The motivation for having a AcceleratorRegistry is to make it convenient
+    for the Users to try different accelerators by passing mapped aliases
+    to the accelerator flag to the Trainer.
+
+    Example::
+
+        @AcceleratorRegistry.register("sota", description="Custom sota accelerator", a=1, b=True)
+        class SOTAAccelerator(Accelerator):
+            def __init__(self, a, b):
+                ...
+
+        or
+
+        AcceleratorRegistry.register("sota", SOTAAccelerator, description="Custom sota accelerator", a=1, b=True)
+    """
+
+    def register(
+        self,
+        name: str,
+        accelerator: Optional[Callable] = None,
+        description: str = "",
+        override: bool = False,
+        **init_params: Any,
+    ) -> Callable:
+        """Registers a accelerator mapped to a name and with required metadata.
+
+        Args:
+            name : the name that identifies a accelerator, e.g. "gpu"
+            accelerator : accelerator class
+            description : accelerator description
+            override : overrides the registered accelerator, if True
+            init_params: parameters to initialize the accelerator
+        """
+        if not (name is None or isinstance(name, str)):
+            raise TypeError(f"`name` must be a str, found {name}")
+
+        if name in self and not override:
+            raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.")
+
+        data: Dict[str, Any] = {}
+
+        data["description"] = description
+        data["init_params"] = init_params
+
+        def do_register(name: str, accelerator: Callable) -> Callable:
+            data["accelerator"] = accelerator
+            data["accelerator_name"] = name
+            self[name] = data
+            return accelerator
+
+        if accelerator is not None:
+            return do_register(name, accelerator)
+
+        return do_register
+
+    def get(self, name: str, default: Optional[Any] = None) -> Any:
+        """Calls the registered accelerator with the required parameters and returns the accelerator object.
+
+        Args:
+            name (str): the name that identifies a accelerator, e.g. "gpu"
+        """
+        if name in self:
+            data = self[name]
+            return data["accelerator"](**data["init_params"])
+
+        if default is not None:
+            return default
+
+        err_msg = "'{}' not found in registry. Available names: {}"
+        available_names = self.available_accelerators()
+        raise KeyError(err_msg.format(name, available_names))
+
+    def remove(self, name: str) -> None:
+        """Removes the registered accelerator by name."""
+        self.pop(name)
+
+    def available_accelerators(self) -> List[str]:
+        """Returns a list of registered accelerators."""
+        return list(self.keys())
+
+    def __str__(self) -> str:
+        return "Registered Accelerators: {}".format(", ".join(self.available_accelerators()))
+
+
+AcceleratorRegistry = _AcceleratorRegistry()
+
+
+def call_register_accelerators(base_module: str) -> None:
+    module = importlib.import_module(base_module)
+    for _, mod in getmembers(module, isclass):
+        if issubclass(mod, Accelerator) and _is_register_method_overridden(mod, Accelerator, "register_accelerators"):
+            mod.register_accelerators(AcceleratorRegistry)
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index dfdc950e70124..fa8bd007cb25f 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -65,7 +65,10 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return _TPU_AVAILABLE
 
-    @staticmethod
-    def name() -> str:
-        """Name of the Accelerator."""
-        return "tpu"
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        accelerator_registry.register(
+            "tpu",
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
diff --git a/pytorch_lightning/strategies/__init__.py b/pytorch_lightning/strategies/__init__.py
index f06edfa53ec7a..a4cd57a50ac1d 100644
--- a/pytorch_lightning/strategies/__init__.py
+++ b/pytorch_lightning/strategies/__init__.py
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from pathlib import Path
-
 from pytorch_lightning.strategies.bagua import BaguaStrategy  # noqa: F401
 from pytorch_lightning.strategies.ddp import DDPStrategy  # noqa: F401
 from pytorch_lightning.strategies.ddp2 import DDP2Strategy  # noqa: F401
@@ -31,7 +29,6 @@
 from pytorch_lightning.strategies.strategy_registry import call_register_strategies, StrategyRegistry  # noqa: F401
 from pytorch_lightning.strategies.tpu_spawn import TPUSpawnStrategy  # noqa: F401
 
-FILE_ROOT = Path(__file__).parent
 STRATEGIES_BASE_MODULE = "pytorch_lightning.strategies"
 
-call_register_strategies(FILE_ROOT, STRATEGIES_BASE_MODULE)
+call_register_strategies(STRATEGIES_BASE_MODULE)
diff --git a/pytorch_lightning/strategies/strategy_registry.py b/pytorch_lightning/strategies/strategy_registry.py
index 17e08acb23bcc..7dee7146d415d 100644
--- a/pytorch_lightning/strategies/strategy_registry.py
+++ b/pytorch_lightning/strategies/strategy_registry.py
@@ -12,13 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
-import inspect
 from inspect import getmembers, isclass
-from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional
 
 from pytorch_lightning.strategies.strategy import Strategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.registry import _is_register_method_overridden
 
 
 class _StrategyRegistry(dict):
@@ -116,22 +115,8 @@ def __str__(self) -> str:
 StrategyRegistry = _StrategyRegistry()
 
 
-def is_register_strategies_overridden(strategy: type) -> bool:
-
-    method_name = "register_strategies"
-    strategy_attr = getattr(strategy, method_name)
-    previous_super_cls = inspect.getmro(strategy)[1]
-
-    if issubclass(previous_super_cls, Strategy):
-        super_attr = getattr(previous_super_cls, method_name)
-    else:
-        return False
-
-    return strategy_attr.__code__ is not super_attr.__code__
-
-
-def call_register_strategies(root: Path, base_module: str) -> None:
+def call_register_strategies(base_module: str) -> None:
     module = importlib.import_module(base_module)
     for _, mod in getmembers(module, isclass):
-        if issubclass(mod, Strategy) and is_register_strategies_overridden(mod):
+        if issubclass(mod, Strategy) and _is_register_method_overridden(mod, Strategy, "register_strategies"):
             mod.register_strategies(StrategyRegistry)
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 0d2013c1606cf..e1cf4c6232f90 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -23,6 +23,7 @@
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.accelerators.ipu import IPUAccelerator
+from pytorch_lightning.accelerators.registry import AcceleratorRegistry
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
 from pytorch_lightning.plugins import (
     ApexMixedPrecisionPlugin,
@@ -157,7 +158,7 @@ def __init__(
         # 1. Parsing flags
         # Get registered strategies, built-in accelerators and precision plugins
         self._registered_strategies = StrategyRegistry.available_strategies()
-        self._accelerator_types = ("tpu", "ipu", "gpu", "cpu")
+        self._accelerator_types = AcceleratorRegistry.available_accelerators()
         self._precision_types = ("16", "32", "64", "bf16", "mixed")
 
         # Raise an exception if there are conflicts between flags
@@ -486,32 +487,27 @@ def _choose_accelerator(self) -> str:
         return "cpu"
 
     def _set_parallel_devices_and_init_accelerator(self) -> None:
-        ACCELERATORS = {
-            "cpu": CPUAccelerator,
-            "gpu": GPUAccelerator,
-            "tpu": TPUAccelerator,
-            "ipu": IPUAccelerator,
-        }
         if isinstance(self._accelerator_flag, Accelerator):
             self.accelerator: Accelerator = self._accelerator_flag
         else:
             assert self._accelerator_flag is not None
             self._accelerator_flag = self._accelerator_flag.lower()
-            if self._accelerator_flag not in ACCELERATORS:
+            if self._accelerator_flag not in AcceleratorRegistry:
                 raise MisconfigurationException(
                     "When passing string value for the `accelerator` argument of `Trainer`,"
-                    f" it can only be one of {list(ACCELERATORS)}."
+                    f" it can only be one of {self._accelerator_types}."
                 )
-            accelerator_class = ACCELERATORS[self._accelerator_flag]
-            self.accelerator = accelerator_class()  # type: ignore[abstract]
+            self.accelerator = AcceleratorRegistry.get(self._accelerator_flag)
 
         if not self.accelerator.is_available():
-            available_accelerator = [acc_str for acc_str in list(ACCELERATORS) if ACCELERATORS[acc_str].is_available()]
+            available_accelerator = [
+                acc_str for acc_str in self._accelerator_types if AcceleratorRegistry.get(acc_str).is_available()
+            ]
             raise MisconfigurationException(
                 f"{self.accelerator.__class__.__qualname__} can not run on your system"
-                f" since {self.accelerator.name().upper()}s are not available."
-                " The following accelerator(s) is available and can be passed into"
-                f" `accelerator` argument of `Trainer`: {available_accelerator}."
+                " since the accelerator is not available. The following accelerator(s)"
+                " is available and can be passed into `accelerator` argument of"
+                f" `Trainer`: {available_accelerator}."
             )
 
         self._set_devices_flag_if_auto_passed()
diff --git a/pytorch_lightning/utilities/registry.py b/pytorch_lightning/utilities/registry.py
new file mode 100644
index 0000000000000..83970e885bdcd
--- /dev/null
+++ b/pytorch_lightning/utilities/registry.py
@@ -0,0 +1,27 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+from typing import Any
+
+
+def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
+    mod_attr = getattr(mod, method)
+    previous_super_cls = inspect.getmro(mod)[1]
+
+    if issubclass(previous_super_cls, base_cls):
+        super_attr = getattr(previous_super_cls, method)
+    else:
+        return False
+
+    return mod_attr.__code__ is not super_attr.__code__
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index fce6a7fae2502..794cb9b2922cd 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -498,7 +498,8 @@ def test_accelerator_cpu(_):
     with pytest.raises(MisconfigurationException, match="You requested gpu:"):
         trainer = Trainer(gpus=1)
     with pytest.raises(
-        MisconfigurationException, match="GPUAccelerator can not run on your system since GPUs are not available."
+        MisconfigurationException,
+        match="GPUAccelerator can not run on your system since the accelerator is not available.",
     ):
         trainer = Trainer(accelerator="gpu")
     with pytest.raises(MisconfigurationException, match="You requested gpu:"):
diff --git a/tests/accelerators/test_accelerator_registry.py b/tests/accelerators/test_accelerator_registry.py
new file mode 100644
index 0000000000000..b21cd95e33cbd
--- /dev/null
+++ b/tests/accelerators/test_accelerator_registry.py
@@ -0,0 +1,66 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning import Trainer
+from pytorch_lightning.accelerators import Accelerator, AcceleratorRegistry
+
+
+def test_accelerator_registry_with_new_accelerator():
+
+    accelerator_name = "custom_accelerator"
+    accelerator_description = "Custom Accelerator"
+
+    class CustomAccelerator(Accelerator):
+        def __init__(self, param1, param2):
+            self.param1 = param1
+            self.param2 = param2
+            super().__init__()
+
+        @staticmethod
+        def parse_devices(devices):
+            return devices
+
+        @staticmethod
+        def get_parallel_devices(devices):
+            return ["foo"] * devices
+
+        @staticmethod
+        def auto_device_count():
+            return 3
+
+        @staticmethod
+        def is_available():
+            return True
+
+    AcceleratorRegistry.register(
+        accelerator_name, CustomAccelerator, description=accelerator_description, param1="abc", param2=123
+    )
+
+    assert accelerator_name in AcceleratorRegistry
+
+    assert AcceleratorRegistry[accelerator_name]["description"] == accelerator_description
+    assert AcceleratorRegistry[accelerator_name]["init_params"] == {"param1": "abc", "param2": 123}
+    assert AcceleratorRegistry[accelerator_name]["accelerator_name"] == accelerator_name
+
+    assert isinstance(AcceleratorRegistry.get(accelerator_name), CustomAccelerator)
+
+    trainer = Trainer(accelerator=accelerator_name, devices="auto")
+    assert isinstance(trainer.accelerator, CustomAccelerator)
+    assert trainer._accelerator_connector.parallel_devices == ["foo"] * 3
+
+    AcceleratorRegistry.remove(accelerator_name)
+    assert accelerator_name not in AcceleratorRegistry
+
+
+def test_available_accelerators_in_registry():
+    assert AcceleratorRegistry.available_accelerators() == ["cpu", "gpu", "ipu", "tpu"]
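
For reference, a minimal usage sketch of the registry added in this patch (not part of the diff itself). It mirrors the functional `AcceleratorRegistry.register(...)` call exercised in `tests/accelerators/test_accelerator_registry.py`; the `MyAccelerator` class, its `param1`/`param2` arguments, and the `"foo"` device placeholders are hypothetical names used only for illustration.

```python
from pytorch_lightning.accelerators import Accelerator, AcceleratorRegistry


# Hypothetical accelerator, overriding only the abstract hooks that the
# registry test in this patch also overrides.
class MyAccelerator(Accelerator):
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2
        super().__init__()

    @staticmethod
    def parse_devices(devices):
        return devices

    @staticmethod
    def get_parallel_devices(devices):
        return ["foo"] * devices

    @staticmethod
    def auto_device_count():
        return 3

    @staticmethod
    def is_available():
        return True


# Store the class together with the init params it should later be constructed with.
AcceleratorRegistry.register(
    "my_accelerator", MyAccelerator, description="Example accelerator", param1="abc", param2=123
)

# `get` instantiates the registered class with the stored init params,
# which is how the accelerator connector resolves `Trainer(accelerator="my_accelerator")`.
assert isinstance(AcceleratorRegistry.get("my_accelerator"), MyAccelerator)
```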