
Add AcceleratorRegistry #12180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 21 commits · Mar 24, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -164,6 +164,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Callback.state_dict()` and `Callback.load_state_dict()` methods ([#12232](https://github.com/PyTorchLightning/pytorch-lightning/pull/12232))


- Added `AcceleratorRegistry` ([#12180](https://github.com/PyTorchLightning/pytorch-lightning/pull/12180))


### Changed

- Drop PyTorch 1.7 support ([#12191](https://github.com/PyTorchLightning/pytorch-lightning/pull/12191))
5 changes: 5 additions & 0 deletions pytorch_lightning/accelerators/__init__.py
@@ -14,4 +14,9 @@
from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.gpu import GPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.ipu import IPUAccelerator # noqa: F401
from pytorch_lightning.accelerators.registry import AcceleratorRegistry, call_register_accelerators # noqa: F401
from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa: F401

ACCELERATORS_BASE_MODULE = "pytorch_lightning.accelerators"
Contributor:
Why expose this string?

If it was meant to avoid registering these by default, it would not work: to change the path you would need to import the variable, and at import time the built-in accelerators would get registered anyway.


call_register_accelerators(ACCELERATORS_BASE_MODULE)
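
A minimal sketch (not part of the PR) of the reviewer's point above: the built-in accelerators are registered as a side effect of importing the package, so rebinding ACCELERATORS_BASE_MODULE afterwards changes nothing. The "my_project" path is made up.

import pytorch_lightning.accelerators as accelerators

# The import above already ran call_register_accelerators("pytorch_lightning.accelerators"),
# so the built-ins are in the registry at this point.
print(accelerators.AcceleratorRegistry.available_accelerators())  # e.g. ['cpu', 'gpu', 'ipu', 'tpu']

# Rebinding the module-level constant has no effect on what was registered:
accelerators.ACCELERATORS_BASE_MODULE = "my_project.accelerators"
print(accelerators.AcceleratorRegistry.available_accelerators())  # unchanged
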
7 changes: 3 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -75,7 +75,6 @@ def auto_device_count() -> int:
def is_available() -> bool:
"""Detect if the hardware is available."""

@staticmethod
@abstractmethod
def name() -> str:
Contributor:
I don't get the reasoning for removing `name()`.

If we kept it, the registry could use it to automatically define the registered name for the class.

Contributor (Author):
Related discussion #12180 (comment)

"""Name of the Accelerator."""
@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
Contributor:
The typing is wrong; it should be the actual accelerator registry class.

pass
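
A rough sketch (not what the PR implements; _register_builtin is a made-up helper) of the reviewer's suggestion above: if name() were kept, the registry could derive the registered alias from the class itself instead of each accelerator passing a string literal.

def _register_builtin(registry: "_AcceleratorRegistry", accelerator_cls: type) -> None:
    # Hypothetical: use the accelerator's own name() as the registry key, so
    # CPUAccelerator (whose name() returns "cpu") registers itself under "cpu".
    registry.register(accelerator_cls.name(), accelerator_cls, description=accelerator_cls.__name__)
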
11 changes: 7 additions & 4 deletions pytorch_lightning/accelerators/cpu.py
@@ -63,7 +63,10 @@ def is_available() -> bool:
"""CPU is always available for execution."""
return True

@staticmethod
def name() -> str:
"""Name of the Accelerator."""
return "cpu"
@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
Contributor:
The argument should be typed as _AcceleratorRegistry, not Dict

accelerator_registry.register(
"cpu",
cls,
description=f"{cls.__class__.__name__}",
)
11 changes: 7 additions & 4 deletions pytorch_lightning/accelerators/gpu.py
@@ -93,10 +93,13 @@ def auto_device_count() -> int:
def is_available() -> bool:
return torch.cuda.device_count() > 0

@staticmethod
def name() -> str:
"""Name of the Accelerator."""
return "gpu"
@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
accelerator_registry.register(
"gpu",
cls,
description=f"{cls.__class__.__name__}",
Contributor:
__class__.__name__ is already a string, so the f-string is redundant.

)


def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:
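
A standalone side note (made-up classes, not PL code and not from the review thread): inside a classmethod, cls.__class__.__name__ is the metaclass's name rather than the class's own name, so cls.__name__ would be the more direct choice for the description above.

from abc import ABC

class Base(ABC):
    @classmethod
    def describe(cls) -> str:
        # cls.__name__ is the class's own name; cls.__class__ is its metaclass.
        return f"{cls.__name__} / {cls.__class__.__name__}"

class Gpu(Base):
    pass

print(Gpu.describe())  # prints "Gpu / ABCMeta"
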
11 changes: 7 additions & 4 deletions pytorch_lightning/accelerators/ipu.py
@@ -47,7 +47,10 @@ def auto_device_count() -> int:
def is_available() -> bool:
return _IPU_AVAILABLE

@staticmethod
def name() -> str:
"""Name of the Accelerator."""
return "ipu"
@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
accelerator_registry.register(
"ipu",
cls,
description=f"{cls.__class__.__name__}",
)
122 changes: 122 additions & 0 deletions pytorch_lightning/accelerators/registry.py
@@ -0,0 +1,122 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from inspect import getmembers, isclass
from typing import Any, Callable, Dict, List, Optional

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.registry import _is_register_method_overridden


class _AcceleratorRegistry(dict):
"""This class is a Registry that stores information about the Accelerators.

The Accelerators are mapped to strings. These strings are names that identify
an accelerator, e.g., "gpu". It also returns Optional description and
parameters to initialize the Accelerator, which were defined during the
registration.
Contributor, on lines +27 to +29:
... an optional description ...


The motivation for having a AcceleratorRegistry is to make it convenient
for the Users to try different accelerators by passing mapped aliases
Contributor:
"Users" should not be capitalized.

to the accelerator flag to the Trainer.

Example::

@AcceleratorRegistry.register("sota", description="Custom sota accelerator", a=1, b=True)
class SOTAAccelerator(Accelerator):
def __init__(self, a, b):
...
Contributor, on lines +37 to +40:
This doesn't work:

Traceback (most recent call last):
  File "/Users/carmocca/git/pytorch-lightning/thing.py", line 5, in <module>
    class SOTAAccelerator:
TypeError: do_register() missing 1 required positional argument: 'accelerator'


It only works when called directly:

AcceleratorRegistry.register("sota", SOTAAccelerator, description="Custom sota accelerator", a=1, b=True)
"""

def register(
self,
name: str,
accelerator: Optional[Callable] = None,
Contributor:
The type should be Optional[Type], not Callable

description: str = "",
override: bool = False,
**init_params: Any,
) -> Callable:
"""Registers a accelerator mapped to a name and with required metadata.

Args:
name : the name that identifies a accelerator, e.g. "gpu"
accelerator : accelerator class
description : accelerator description
override : overrides the registered accelerator, if True
init_params: parameters to initialize the accelerator
"""
if not (name is None or isinstance(name, str)):
raise TypeError(f"`name` must be a str, found {name}")

if name in self and not override:
raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.")

data: Dict[str, Any] = {}

data["description"] = description
data["init_params"] = init_params

def do_register(name: str, accelerator: Callable) -> Callable:
data["accelerator"] = accelerator
data["accelerator_name"] = name
self[name] = data
return accelerator

if accelerator is not None:
return do_register(name, accelerator)

return do_register

def get(self, name: str, default: Optional[Any] = None) -> Any:
"""Calls the registered accelerator with the required parameters and returns the accelerator object.

Args:
name (str): the name that identifies a accelerator, e.g. "gpu"
"""
if name in self:
data = self[name]
return data["accelerator"](**data["init_params"])

if default is not None:
return default

err_msg = "'{}' not found in registry. Available names: {}"
available_names = self.available_accelerators()
raise KeyError(err_msg.format(name, available_names))

def remove(self, name: str) -> None:
"""Removes the registered accelerator by name."""
self.pop(name)

def available_accelerators(self) -> List[str]:
"""Returns a list of registered accelerators."""
return list(self.keys())

def __str__(self) -> str:
return "Registered Accelerators: {}".format(", ".join(self.available_accelerators()))


AcceleratorRegistry = _AcceleratorRegistry()


def call_register_accelerators(base_module: str) -> None:
module = importlib.import_module(base_module)
for _, mod in getmembers(module, isclass):
if issubclass(mod, Accelerator) and _is_register_method_overridden(mod, Accelerator, "register_accelerators"):
mod.register_accelerators(AcceleratorRegistry)
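
A hedged usage sketch of the registry as added here (SOTAAccelerator and its a/b parameters are illustrative, echoing the docstring; the direct-call form is used because the decorator form raises, per the review comment above). A real accelerator would subclass Accelerator and implement its abstract methods; a plain stand-in class keeps the snippet short.

from pytorch_lightning.accelerators import AcceleratorRegistry

class SOTAAccelerator:
    def __init__(self, a, b):
        self.a, self.b = a, b

AcceleratorRegistry.register("sota", SOTAAccelerator, description="Custom sota accelerator", a=1, b=True)

instance = AcceleratorRegistry.get("sota")            # calls SOTAAccelerator(a=1, b=True)
print(AcceleratorRegistry.available_accelerators())   # built-ins plus "sota"
AcceleratorRegistry.remove("sota")
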
11 changes: 7 additions & 4 deletions pytorch_lightning/accelerators/tpu.py
@@ -65,7 +65,10 @@ def auto_device_count() -> int:
def is_available() -> bool:
return _TPU_AVAILABLE

@staticmethod
def name() -> str:
"""Name of the Accelerator."""
return "tpu"
@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
accelerator_registry.register(
"tpu",
cls,
description=f"{cls.__class__.__name__}",
)
5 changes: 1 addition & 4 deletions pytorch_lightning/strategies/__init__.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401
from pytorch_lightning.strategies.ddp import DDPStrategy # noqa: F401
from pytorch_lightning.strategies.ddp2 import DDP2Strategy # noqa: F401
@@ -31,7 +29,6 @@
from pytorch_lightning.strategies.strategy_registry import call_register_strategies, StrategyRegistry # noqa: F401
from pytorch_lightning.strategies.tpu_spawn import TPUSpawnStrategy # noqa: F401

FILE_ROOT = Path(__file__).parent
STRATEGIES_BASE_MODULE = "pytorch_lightning.strategies"

call_register_strategies(FILE_ROOT, STRATEGIES_BASE_MODULE)
call_register_strategies(STRATEGIES_BASE_MODULE)
21 changes: 3 additions & 18 deletions pytorch_lightning/strategies/strategy_registry.py
@@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
from inspect import getmembers, isclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.registry import _is_register_method_overridden


class _StrategyRegistry(dict):
@@ -116,22 +115,8 @@ def __str__(self) -> str:
StrategyRegistry = _StrategyRegistry()


def is_register_strategies_overridden(strategy: type) -> bool:

method_name = "register_strategies"
strategy_attr = getattr(strategy, method_name)
previous_super_cls = inspect.getmro(strategy)[1]

if issubclass(previous_super_cls, Strategy):
super_attr = getattr(previous_super_cls, method_name)
else:
return False

return strategy_attr.__code__ is not super_attr.__code__


def call_register_strategies(root: Path, base_module: str) -> None:
def call_register_strategies(base_module: str) -> None:
module = importlib.import_module(base_module)
for _, mod in getmembers(module, isclass):
if issubclass(mod, Strategy) and is_register_strategies_overridden(mod):
if issubclass(mod, Strategy) and _is_register_method_overridden(mod, Strategy, "register_strategies"):
mod.register_strategies(StrategyRegistry)
26 changes: 11 additions & 15 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -23,6 +23,7 @@
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.gpu import GPUAccelerator
from pytorch_lightning.accelerators.ipu import IPUAccelerator
from pytorch_lightning.accelerators.registry import AcceleratorRegistry
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.plugins import (
ApexMixedPrecisionPlugin,
@@ -157,7 +158,7 @@ def __init__(
# 1. Parsing flags
# Get registered strategies, built-in accelerators and precision plugins
self._registered_strategies = StrategyRegistry.available_strategies()
self._accelerator_types = ("tpu", "ipu", "gpu", "cpu")
self._accelerator_types = AcceleratorRegistry.available_accelerators()
self._precision_types = ("16", "32", "64", "bf16", "mixed")

# Raise an exception if there are conflicts between flags
@@ -486,32 +487,27 @@ def _choose_accelerator(self) -> str:
return "cpu"

def _set_parallel_devices_and_init_accelerator(self) -> None:
ACCELERATORS = {
"cpu": CPUAccelerator,
"gpu": GPUAccelerator,
"tpu": TPUAccelerator,
"ipu": IPUAccelerator,
}
if isinstance(self._accelerator_flag, Accelerator):
self.accelerator: Accelerator = self._accelerator_flag
else:
assert self._accelerator_flag is not None
self._accelerator_flag = self._accelerator_flag.lower()
if self._accelerator_flag not in ACCELERATORS:
if self._accelerator_flag not in AcceleratorRegistry:
raise MisconfigurationException(
"When passing string value for the `accelerator` argument of `Trainer`,"
f" it can only be one of {list(ACCELERATORS)}."
f" it can only be one of {self._accelerator_types}."
)
accelerator_class = ACCELERATORS[self._accelerator_flag]
self.accelerator = accelerator_class() # type: ignore[abstract]
self.accelerator = AcceleratorRegistry.get(self._accelerator_flag)

if not self.accelerator.is_available():
available_accelerator = [acc_str for acc_str in list(ACCELERATORS) if ACCELERATORS[acc_str].is_available()]
available_accelerator = [
acc_str for acc_str in self._accelerator_types if AcceleratorRegistry.get(acc_str).is_available()
]
raise MisconfigurationException(
f"{self.accelerator.__class__.__qualname__} can not run on your system"
f" since {self.accelerator.name().upper()}s are not available."
" The following accelerator(s) is available and can be passed into"
f" `accelerator` argument of `Trainer`: {available_accelerator}."
" since the accelerator is not available. The following accelerator(s)"
" is available and can be passed into `accelerator` argument of"
f" `Trainer`: {available_accelerator}."
)

self._set_devices_flag_if_auto_passed()
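
A small sketch (behavior inferred from this diff; "quantum" is a made-up alias) of what the change means at the Trainer level: string values for the accelerator flag are now resolved through AcceleratorRegistry instead of a hard-coded dict.

from pytorch_lightning import Trainer

trainer = Trainer(accelerator="cpu")  # resolved via AcceleratorRegistry.get("cpu")

# An unknown alias is rejected against the registry's contents:
# Trainer(accelerator="quantum") raises MisconfigurationException listing
# AcceleratorRegistry.available_accelerators().
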
27 changes: 27 additions & 0 deletions pytorch_lightning/utilities/registry.py
@@ -0,0 +1,27 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any


def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
Contributor:
Why not use is_overridden?

mod_attr = getattr(mod, method)
previous_super_cls = inspect.getmro(mod)[1]

if issubclass(previous_super_cls, base_cls):
super_attr = getattr(previous_super_cls, method)
else:
return False

return mod_attr.__code__ is not super_attr.__code__
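
A standalone illustration (made-up classes, not PL code) of the check _is_register_method_overridden performs: a class is asked to register itself only when it actually overrides the hook, which is detected by comparing the underlying code objects.

class Base:
    @classmethod
    def register_accelerators(cls, registry: dict) -> None:
        pass

class Custom(Base):
    @classmethod
    def register_accelerators(cls, registry: dict) -> None:
        registry["custom"] = cls

class Plain(Base):
    pass

def overrides_hook(klass: type) -> bool:
    # Same idea as _is_register_method_overridden: the hook counts as overridden
    # when the subclass's code object differs from the base class's.
    return klass.register_accelerators.__code__ is not Base.register_accelerators.__code__

print(overrides_hook(Custom))  # True
print(overrides_hook(Plain))   # False
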
3 changes: 2 additions & 1 deletion tests/accelerators/test_accelerator_connector.py
@@ -498,7 +498,8 @@ def test_accelerator_cpu(_):
with pytest.raises(MisconfigurationException, match="You requested gpu:"):
trainer = Trainer(gpus=1)
with pytest.raises(
MisconfigurationException, match="GPUAccelerator can not run on your system since GPUs are not available."
MisconfigurationException,
match="GPUAccelerator can not run on your system since the accelerator is not available.",
):
trainer = Trainer(accelerator="gpu")
with pytest.raises(MisconfigurationException, match="You requested gpu:"):