
Commit dcc973e

kaushikb11 authored
Add AcceleratorRegistry (#12180)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: ananthsub <[email protected]>
1 parent c099c8b commit dcc973e

File tree: 14 files changed, +271 −58 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -164,6 +164,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `Callback.state_dict()` and `Callback.load_state_dict()` methods ([#12232](https://github.com/PyTorchLightning/pytorch-lightning/pull/12232))
 
 
+- Added `AcceleratorRegistry` ([#12180](https://github.com/PyTorchLightning/pytorch-lightning/pull/12180))
+
+
 ### Changed
 
 - Drop PyTorch 1.7 support ([#12191](https://github.com/PyTorchLightning/pytorch-lightning/pull/12191))

pytorch_lightning/accelerators/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -14,4 +14,9 @@
 from pytorch_lightning.accelerators.cpu import CPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.gpu import GPUAccelerator  # noqa: F401
 from pytorch_lightning.accelerators.ipu import IPUAccelerator  # noqa: F401
+from pytorch_lightning.accelerators.registry import AcceleratorRegistry, call_register_accelerators  # noqa: F401
 from pytorch_lightning.accelerators.tpu import TPUAccelerator  # noqa: F401
+
+ACCELERATORS_BASE_MODULE = "pytorch_lightning.accelerators"
+
+call_register_accelerators(ACCELERATORS_BASE_MODULE)
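With this wiring, simply importing `pytorch_lightning.accelerators` runs `call_register_accelerators`, so the built-in accelerators are available under their string aliases right away. A minimal sketch of what that enables (the listed output is indicative, not part of this commit):

    from pytorch_lightning.accelerators import AcceleratorRegistry

    # The built-ins were registered on import via call_register_accelerators.
    print(AcceleratorRegistry.available_accelerators())  # e.g. ['cpu', 'gpu', 'ipu', 'tpu']

    # get() instantiates the registered class with its stored init params.
    cpu_accelerator = AcceleratorRegistry.get("cpu")
    print(cpu_accelerator.is_available())  # CPU is always reported as available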

pytorch_lightning/accelerators/accelerator.py

Lines changed: 3 additions & 4 deletions
@@ -75,7 +75,6 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         """Detect if the hardware is available."""
 
-    @staticmethod
-    @abstractmethod
-    def name() -> str:
-        """Name of the Accelerator."""
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        pass

pytorch_lightning/accelerators/cpu.py

Lines changed: 7 additions & 4 deletions
@@ -63,7 +63,10 @@ def is_available() -> bool:
         """CPU is always available for execution."""
         return True
 
-    @staticmethod
-    def name() -> str:
-        """Name of the Accelerator."""
-        return "cpu"
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        accelerator_registry.register(
+            "cpu",
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )

pytorch_lightning/accelerators/gpu.py

Lines changed: 7 additions & 4 deletions
@@ -93,10 +93,13 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return torch.cuda.device_count() > 0
 
-    @staticmethod
-    def name() -> str:
-        """Name of the Accelerator."""
-        return "gpu"
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        accelerator_registry.register(
+            "gpu",
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )
 
 
 def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:

pytorch_lightning/accelerators/ipu.py

Lines changed: 7 additions & 4 deletions
@@ -47,7 +47,10 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return _IPU_AVAILABLE
 
-    @staticmethod
-    def name() -> str:
-        """Name of the Accelerator."""
-        return "ipu"
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        accelerator_registry.register(
+            "ipu",
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )

pytorch_lightning/accelerators/registry.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from inspect import getmembers, isclass
from typing import Any, Callable, Dict, List, Optional

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.registry import _is_register_method_overridden


class _AcceleratorRegistry(dict):
    """This class is a Registry that stores information about the Accelerators.

    The Accelerators are mapped to strings. These strings are names that identify
    an accelerator, e.g., "gpu". It also returns Optional description and
    parameters to initialize the Accelerator, which were defined during the
    registration.

    The motivation for having a AcceleratorRegistry is to make it convenient
    for the Users to try different accelerators by passing mapped aliases
    to the accelerator flag to the Trainer.

    Example::

        @AcceleratorRegistry.register("sota", description="Custom sota accelerator", a=1, b=True)
        class SOTAAccelerator(Accelerator):
            def __init__(self, a, b):
                ...

    or

        AcceleratorRegistry.register("sota", SOTAAccelerator, description="Custom sota accelerator", a=1, b=True)
    """

    def register(
        self,
        name: str,
        accelerator: Optional[Callable] = None,
        description: str = "",
        override: bool = False,
        **init_params: Any,
    ) -> Callable:
        """Registers a accelerator mapped to a name and with required metadata.

        Args:
            name : the name that identifies a accelerator, e.g. "gpu"
            accelerator : accelerator class
            description : accelerator description
            override : overrides the registered accelerator, if True
            init_params: parameters to initialize the accelerator
        """
        if not (name is None or isinstance(name, str)):
            raise TypeError(f"`name` must be a str, found {name}")

        if name in self and not override:
            raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.")

        data: Dict[str, Any] = {}

        data["description"] = description
        data["init_params"] = init_params

        def do_register(name: str, accelerator: Callable) -> Callable:
            data["accelerator"] = accelerator
            data["accelerator_name"] = name
            self[name] = data
            return accelerator

        if accelerator is not None:
            return do_register(name, accelerator)

        return do_register

    def get(self, name: str, default: Optional[Any] = None) -> Any:
        """Calls the registered accelerator with the required parameters and returns the accelerator object.

        Args:
            name (str): the name that identifies a accelerator, e.g. "gpu"
        """
        if name in self:
            data = self[name]
            return data["accelerator"](**data["init_params"])

        if default is not None:
            return default

        err_msg = "'{}' not found in registry. Available names: {}"
        available_names = self.available_accelerators()
        raise KeyError(err_msg.format(name, available_names))

    def remove(self, name: str) -> None:
        """Removes the registered accelerator by name."""
        self.pop(name)

    def available_accelerators(self) -> List[str]:
        """Returns a list of registered accelerators."""
        return list(self.keys())

    def __str__(self) -> str:
        return "Registered Accelerators: {}".format(", ".join(self.available_accelerators()))


AcceleratorRegistry = _AcceleratorRegistry()


def call_register_accelerators(base_module: str) -> None:
    module = importlib.import_module(base_module)
    for _, mod in getmembers(module, isclass):
        if issubclass(mod, Accelerator) and _is_register_method_overridden(mod, Accelerator, "register_accelerators"):
            mod.register_accelerators(AcceleratorRegistry)
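As a usage sketch of the registry defined above: `register` stores a class together with its description and init parameters, and `get` instantiates it later. The `ToyAccelerator` class below is a hypothetical stand-in invented for illustration (note that `register` itself does not enforce an `Accelerator` subclass):

    from pytorch_lightning.accelerators import AcceleratorRegistry

    class ToyAccelerator:  # hypothetical example class, not part of this commit
        def __init__(self, device_count: int):
            self.device_count = device_count

    # Functional form: map an alias to the class and keep its init params for later.
    AcceleratorRegistry.register("toy", ToyAccelerator, description="Toy example", device_count=2)

    toy = AcceleratorRegistry.get("toy")  # equivalent to ToyAccelerator(device_count=2)
    print(AcceleratorRegistry)            # "Registered Accelerators: ..., toy"

Registering the same name twice raises a `MisconfigurationException` unless `override=True` is passed, and `get` on an unknown name raises a `KeyError` listing the available aliases.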

pytorch_lightning/accelerators/tpu.py

Lines changed: 7 additions & 4 deletions
@@ -65,7 +65,10 @@ def auto_device_count() -> int:
     def is_available() -> bool:
         return _TPU_AVAILABLE
 
-    @staticmethod
-    def name() -> str:
-        """Name of the Accelerator."""
-        return "tpu"
+    @classmethod
+    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+        accelerator_registry.register(
+            "tpu",
+            cls,
+            description=f"{cls.__class__.__name__}",
+        )

pytorch_lightning/strategies/__init__.py

Lines changed: 1 addition & 4 deletions
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from pathlib import Path
-
 from pytorch_lightning.strategies.bagua import BaguaStrategy  # noqa: F401
 from pytorch_lightning.strategies.ddp import DDPStrategy  # noqa: F401
 from pytorch_lightning.strategies.ddp2 import DDP2Strategy  # noqa: F401
@@ -31,7 +29,6 @@
 from pytorch_lightning.strategies.strategy_registry import call_register_strategies, StrategyRegistry  # noqa: F401
 from pytorch_lightning.strategies.tpu_spawn import TPUSpawnStrategy  # noqa: F401
 
-FILE_ROOT = Path(__file__).parent
 STRATEGIES_BASE_MODULE = "pytorch_lightning.strategies"
 
-call_register_strategies(FILE_ROOT, STRATEGIES_BASE_MODULE)
+call_register_strategies(STRATEGIES_BASE_MODULE)

pytorch_lightning/strategies/strategy_registry.py

Lines changed: 3 additions & 18 deletions
@@ -12,13 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
-import inspect
 from inspect import getmembers, isclass
-from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional
 
 from pytorch_lightning.strategies.strategy import Strategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.registry import _is_register_method_overridden
 
 
 class _StrategyRegistry(dict):
@@ -116,22 +115,8 @@ def __str__(self) -> str:
 StrategyRegistry = _StrategyRegistry()
 
 
-def is_register_strategies_overridden(strategy: type) -> bool:
-
-    method_name = "register_strategies"
-    strategy_attr = getattr(strategy, method_name)
-    previous_super_cls = inspect.getmro(strategy)[1]
-
-    if issubclass(previous_super_cls, Strategy):
-        super_attr = getattr(previous_super_cls, method_name)
-    else:
-        return False
-
-    return strategy_attr.__code__ is not super_attr.__code__
-
-
-def call_register_strategies(root: Path, base_module: str) -> None:
+def call_register_strategies(base_module: str) -> None:
     module = importlib.import_module(base_module)
     for _, mod in getmembers(module, isclass):
-        if issubclass(mod, Strategy) and is_register_strategies_overridden(mod):
+        if issubclass(mod, Strategy) and _is_register_method_overridden(mod, Strategy, "register_strategies"):
             mod.register_strategies(StrategyRegistry)

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 11 additions & 15 deletions
@@ -23,6 +23,7 @@
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.accelerators.ipu import IPUAccelerator
+from pytorch_lightning.accelerators.registry import AcceleratorRegistry
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
 from pytorch_lightning.plugins import (
     ApexMixedPrecisionPlugin,
@@ -157,7 +158,7 @@ def __init__(
         # 1. Parsing flags
         # Get registered strategies, built-in accelerators and precision plugins
         self._registered_strategies = StrategyRegistry.available_strategies()
-        self._accelerator_types = ("tpu", "ipu", "gpu", "cpu")
+        self._accelerator_types = AcceleratorRegistry.available_accelerators()
         self._precision_types = ("16", "32", "64", "bf16", "mixed")
 
         # Raise an exception if there are conflicts between flags
@@ -486,32 +487,27 @@ def _choose_accelerator(self) -> str:
         return "cpu"
 
     def _set_parallel_devices_and_init_accelerator(self) -> None:
-        ACCELERATORS = {
-            "cpu": CPUAccelerator,
-            "gpu": GPUAccelerator,
-            "tpu": TPUAccelerator,
-            "ipu": IPUAccelerator,
-        }
         if isinstance(self._accelerator_flag, Accelerator):
             self.accelerator: Accelerator = self._accelerator_flag
         else:
             assert self._accelerator_flag is not None
             self._accelerator_flag = self._accelerator_flag.lower()
-            if self._accelerator_flag not in ACCELERATORS:
+            if self._accelerator_flag not in AcceleratorRegistry:
                 raise MisconfigurationException(
                     "When passing string value for the `accelerator` argument of `Trainer`,"
-                    f" it can only be one of {list(ACCELERATORS)}."
+                    f" it can only be one of {self._accelerator_types}."
                 )
-            accelerator_class = ACCELERATORS[self._accelerator_flag]
-            self.accelerator = accelerator_class()  # type: ignore[abstract]
+            self.accelerator = AcceleratorRegistry.get(self._accelerator_flag)
 
         if not self.accelerator.is_available():
-            available_accelerator = [acc_str for acc_str in list(ACCELERATORS) if ACCELERATORS[acc_str].is_available()]
+            available_accelerator = [
+                acc_str for acc_str in self._accelerator_types if AcceleratorRegistry.get(acc_str).is_available()
+            ]
             raise MisconfigurationException(
                 f"{self.accelerator.__class__.__qualname__} can not run on your system"
-                f" since {self.accelerator.name().upper()}s are not available."
-                " The following accelerator(s) is available and can be passed into"
-                f" `accelerator` argument of `Trainer`: {available_accelerator}."
+                " since the accelerator is not available. The following accelerator(s)"
+                " is available and can be passed into `accelerator` argument of"
+                f" `Trainer`: {available_accelerator}."
             )
 
         self._set_devices_flag_if_auto_passed()
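The user-facing effect of this connector change is that the `accelerator` string flag is now resolved through the registry instead of a hard-coded dict. A minimal, hedged sketch of the unchanged call site (a typical Trainer invocation, not code from this diff):

    from pytorch_lightning import Trainer

    # "cpu" is resolved via AcceleratorRegistry.get("cpu"); an unknown alias now raises
    # MisconfigurationException listing AcceleratorRegistry.available_accelerators().
    trainer = Trainer(accelerator="cpu", devices=1)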

pytorch_lightning/utilities/registry.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any


def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
    mod_attr = getattr(mod, method)
    previous_super_cls = inspect.getmro(mod)[1]

    if issubclass(previous_super_cls, base_cls):
        super_attr = getattr(previous_super_cls, method)
    else:
        return False

    return mod_attr.__code__ is not super_attr.__code__
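To illustrate the shared helper above: it reports whether a class overrides the given `register_*` hook relative to its immediate parent, which is how both `call_register_accelerators` and `call_register_strategies` decide which classes to auto-register. A small sketch, where `SilentAccelerator` is a hypothetical subclass invented for illustration:

    from pytorch_lightning.accelerators import Accelerator, CPUAccelerator
    from pytorch_lightning.utilities.registry import _is_register_method_overridden

    # CPUAccelerator overrides register_accelerators, so it is picked up.
    print(_is_register_method_overridden(CPUAccelerator, Accelerator, "register_accelerators"))  # True

    class SilentAccelerator(CPUAccelerator):  # inherits the hook unchanged
        pass

    # Not overridden relative to its direct parent, so auto-registration skips it.
    print(_is_register_method_overridden(SilentAccelerator, Accelerator, "register_accelerators"))  # False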

tests/accelerators/test_accelerator_connector.py

Lines changed: 2 additions & 1 deletion
@@ -498,7 +498,8 @@ def test_accelerator_cpu(_):
     with pytest.raises(MisconfigurationException, match="You requested gpu:"):
         trainer = Trainer(gpus=1)
     with pytest.raises(
-        MisconfigurationException, match="GPUAccelerator can not run on your system since GPUs are not available."
+        MisconfigurationException,
+        match="GPUAccelerator can not run on your system since the accelerator is not available.",
     ):
         trainer = Trainer(accelerator="gpu")
     with pytest.raises(MisconfigurationException, match="You requested gpu:"):
