-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Add AcceleratorRegistry
#12180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add AcceleratorRegistry
#12180
Changes from all commits
4f22ae1
ec17fbe
25aca8a
ad81525
ad3c87e
84379ef
5d736d9
4206198
f7f5eda
ef188a1
eb22e01
cdb0572
c8ea7f2
8c96d9c
d18091d
a5925d9
75884c6
197dbaa
b0e1120
d1de49e
096afb9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,7 +75,6 @@ def auto_device_count() -> int: | |
def is_available() -> bool: | ||
"""Detect if the hardware is available.""" | ||
|
||
@staticmethod | ||
@abstractmethod | ||
def name() -> str: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't get the reasoning for removing the If we kept it, the registry could use it to automatically define the name for the class There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Related discussion #12180 (comment) |
||
"""Name of the Accelerator.""" | ||
@classmethod | ||
def register_accelerators(cls, accelerator_registry: Dict) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typing is wrong. should be the actual accelerator registry |
||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,7 +63,10 @@ def is_available() -> bool: | |
"""CPU is always available for execution.""" | ||
return True | ||
|
||
@staticmethod | ||
def name() -> str: | ||
"""Name of the Accelerator.""" | ||
return "cpu" | ||
@classmethod | ||
def register_accelerators(cls, accelerator_registry: Dict) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The argument should be typed as |
||
accelerator_registry.register( | ||
"cpu", | ||
cls, | ||
description=f"{cls.__class__.__name__}", | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,10 +93,13 @@ def auto_device_count() -> int: | |
def is_available() -> bool: | ||
return torch.cuda.device_count() > 0 | ||
|
||
@staticmethod | ||
def name() -> str: | ||
"""Name of the Accelerator.""" | ||
return "gpu" | ||
@classmethod | ||
def register_accelerators(cls, accelerator_registry: Dict) -> None: | ||
accelerator_registry.register( | ||
"gpu", | ||
cls, | ||
description=f"{cls.__class__.__name__}", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
) | ||
|
||
|
||
def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import importlib | ||
from inspect import getmembers, isclass | ||
from typing import Any, Callable, Dict, List, Optional | ||
|
||
from pytorch_lightning.accelerators.accelerator import Accelerator | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
from pytorch_lightning.utilities.registry import _is_register_method_overridden | ||
|
||
|
||
class _AcceleratorRegistry(dict): | ||
"""This class is a Registry that stores information about the Accelerators. | ||
|
||
The Accelerators are mapped to strings. These strings are names that identify | ||
an accelerator, e.g., "gpu". It also returns Optional description and | ||
parameters to initialize the Accelerator, which were defined during the | ||
registration. | ||
Comment on lines
+27
to
+29
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ... an optional description ... |
||
|
||
The motivation for having a AcceleratorRegistry is to make it convenient | ||
for the Users to try different accelerators by passing mapped aliases | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. users capitalized |
||
to the accelerator flag to the Trainer. | ||
|
||
Example:: | ||
|
||
@AcceleratorRegistry.register("sota", description="Custom sota accelerator", a=1, b=True) | ||
class SOTAAccelerator(Accelerator): | ||
def __init__(self, a, b): | ||
... | ||
Comment on lines
+37
to
+40
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't work: Traceback (most recent call last):
File "/Users/carmocca/git/pytorch-lightning/thing.py", line 5, in <module>
class SOTAAccelerator:
TypeError: do_register() missing 1 required positional argument: 'accelerator' |
||
|
||
or | ||
|
||
AcceleratorRegistry.register("sota", SOTAAccelerator, description="Custom sota accelerator", a=1, b=True) | ||
""" | ||
|
||
def register( | ||
self, | ||
name: str, | ||
accelerator: Optional[Callable] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The type should be |
||
description: str = "", | ||
override: bool = False, | ||
**init_params: Any, | ||
) -> Callable: | ||
"""Registers a accelerator mapped to a name and with required metadata. | ||
|
||
Args: | ||
name : the name that identifies a accelerator, e.g. "gpu" | ||
accelerator : accelerator class | ||
description : accelerator description | ||
override : overrides the registered accelerator, if True | ||
init_params: parameters to initialize the accelerator | ||
kaushikb11 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
if not (name is None or isinstance(name, str)): | ||
raise TypeError(f"`name` must be a str, found {name}") | ||
|
||
if name in self and not override: | ||
kaushikb11 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.") | ||
|
||
data: Dict[str, Any] = {} | ||
|
||
data["description"] = description | ||
data["init_params"] = init_params | ||
|
||
def do_register(name: str, accelerator: Callable) -> Callable: | ||
data["accelerator"] = accelerator | ||
data["accelerator_name"] = name | ||
self[name] = data | ||
return accelerator | ||
|
||
if accelerator is not None: | ||
return do_register(name, accelerator) | ||
|
||
return do_register | ||
|
||
def get(self, name: str, default: Optional[Any] = None) -> Any: | ||
"""Calls the registered accelerator with the required parameters and returns the accelerator object. | ||
|
||
Args: | ||
name (str): the name that identifies a accelerator, e.g. "gpu" | ||
""" | ||
if name in self: | ||
data = self[name] | ||
return data["accelerator"](**data["init_params"]) | ||
|
||
if default is not None: | ||
return default | ||
|
||
err_msg = "'{}' not found in registry. Available names: {}" | ||
available_names = self.available_accelerators() | ||
raise KeyError(err_msg.format(name, available_names)) | ||
|
||
def remove(self, name: str) -> None: | ||
"""Removes the registered accelerator by name.""" | ||
self.pop(name) | ||
|
||
def available_accelerators(self) -> List[str]: | ||
"""Returns a list of registered accelerators.""" | ||
return list(self.keys()) | ||
|
||
def __str__(self) -> str: | ||
return "Registered Accelerators: {}".format(", ".join(self.available_accelerators())) | ||
|
||
|
||
AcceleratorRegistry = _AcceleratorRegistry() | ||
|
||
|
||
def call_register_accelerators(base_module: str) -> None: | ||
kaushikb11 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
module = importlib.import_module(base_module) | ||
for _, mod in getmembers(module, isclass): | ||
if issubclass(mod, Accelerator) and _is_register_method_overridden(mod, Accelerator, "register_accelerators"): | ||
mod.register_accelerators(AcceleratorRegistry) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import inspect | ||
from typing import Any | ||
|
||
|
||
def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not use |
||
mod_attr = getattr(mod, method) | ||
previous_super_cls = inspect.getmro(mod)[1] | ||
|
||
if issubclass(previous_super_cls, base_cls): | ||
super_attr = getattr(previous_super_cls, method) | ||
else: | ||
return False | ||
|
||
return mod_attr.__code__ is not super_attr.__code__ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why expose this string?
If it was to avoid registering these by default, it would not work because to change the path you'd need to import the variable
And at import time, they would get registered anyways.