
Add Training Type Plugins Registry #6982

Merged · 27 commits · Apr 16, 2021
Changes from 4 commits

37 changes: 37 additions & 0 deletions pytorch_lightning/plugins/__init__.py
@@ -1,4 +1,5 @@
from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401
from pytorch_lightning.plugins.plugins_registry import TrainingTypePluginsRegistry
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
@@ -47,3 +48,39 @@
'DDPShardedPlugin',
'DDPSpawnShardedPlugin',
]

REGISTRY_TRAINING_TYPE_PLUGINS = [{
"plugin": DeepSpeedPlugin,
"variants": [{
"name": "deepspeed",
"description": "Default DeepSpeed Plugin"
}, {
"name": "deepspeed_stage_2",
"description": "DeepSpeed with ZeRO Stage 2 enabled",
"stage": 2
}, {
"name": "deepspeed_stage_2_offload",
"description": "DeepSpeed with ZeRO Stage 2 enabled and Offload",
"stage": 2,
"cpu_offload": True
}, {
"name": "deepspeed_stage_3",
"description": "DeepSpeed with ZeRO Stage 3 enabled",
"stage": 3
}, {
"name": "deepspeed_stage_3_offload",
"description": "DeepSpeed with ZeRO Stage 2 enabled and Offload",
"stage": 3,
"cpu_offload": True
}]
}]


def register_training_type_plugins(plugins):
    for plugin_info in plugins:
        plugin = plugin_info["plugin"]
        for variant in plugin_info["variants"]:
            # Pass the plugin class as the registry callable; the remaining
            # variant keys (e.g. stage, cpu_offload) become its init params.
            TrainingTypePluginsRegistry.register(func=plugin, **variant)


register_training_type_plugins(REGISTRY_TRAINING_TYPE_PLUGINS)
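
For context, a name registered above can be resolved back into an instantiated plugin via the registry's `get` method (defined in `plugins_registry.py` below). A minimal usage sketch; the Trainer-side wiring that would let users pass these strings to the `plugins` flag is not shown in this diff:

```python
from pytorch_lightning.plugins import TrainingTypePluginsRegistry

# Instantiates DeepSpeedPlugin with the init params recorded at
# registration time (here: stage=2, cpu_offload=True).
plugin = TrainingTypePluginsRegistry.get("deepspeed_stage_2_offload")
```
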
87 changes: 87 additions & 0 deletions pytorch_lightning/plugins/plugins_registry.py
@@ -0,0 +1,87 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import UserDict
from typing import Callable, List, Optional

from pytorch_lightning.utilities.exceptions import MisconfigurationException


class _TrainingTypePluginsRegistry(UserDict):
Member commented:

Should we maybe make this a general registry and reuse it with Flash? We also have a registry in Flash, and duplicating it does not make sense.

I think you would only have to change the naming of the fields.

cc @tchaton

Contributor (Author) replied:

I did think of a generic LightningRegistry that could be used for different types. But for now there has only been a need for one covering TrainingTypePlugins; since _TrainingTypePluginsRegistry is internal, we could change it down the road as well.
"""
This class is a Registry that stores information about the Training Type Plugins.

The Plugins are mapped to strings. These strings are names that idenitify
a plugin, eg., "deepspeed". It also returns Optional description and
parameters to initialize the Plugin, which were defined durng the
registeration.

The motivation for having a PluginRegistry is to make it convenient
for the Users to try different Plugins by passing just strings
to the plugins flag to the Trainer.

"""

def register(
self,
name: str,
func: Optional[Callable] = None,
description: Optional[str] = None,
override: bool = False,
**init_params
):

if not (name is None or isinstance(name, str)):
raise TypeError(f'`name` must be a str, found {name}')

if name in self and not override:
raise MisconfigurationException(
f"'{name}' is already present in the registry."
" HINT: Use `override=True`."
)

data = {}
data["description"] = description if description is not None else ""

data["init_params"] = init_params

        def do_register(func):
            data["func"] = func
            self[name] = data
            # Return the callable itself so `register` also works as a
            # decorator without replacing the decorated class.
            return func

        if func is not None:
            return do_register(func)

        return do_register

def get(self, name: str):
if name in self:
data = self[name]
return data["func"](**data["init_params"])

err_msg = "'{}' not found in registry. Available names: {}"
available_names = ", ".join(sorted(self.keys())) or "none"
raise KeyError(err_msg.format(name, available_names))

def remove(self, name: str):
self.pop(name)

def available_plugins(self) -> List:
return list(self.keys())

def __str__(self):
return "Registered Plugins: {}".format(", ".join(self.keys()))


TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry()
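
To illustrate the two registration paths `register` supports (direct call vs. decorator), a hedged sketch; `MyCustomPlugin` and its `my_param` are made-up stand-ins, not real Lightning classes:

```python
from pytorch_lightning.plugins.plugins_registry import TrainingTypePluginsRegistry


class MyCustomPlugin:  # hypothetical stand-in for a TrainingTypePlugin subclass
    def __init__(self, my_param: int = 0) -> None:
        self.my_param = my_param


# Direct registration: the callable and its init params are stored,
# and only invoked when `get` is called.
TrainingTypePluginsRegistry.register(
    "my_custom", MyCustomPlugin, description="Custom plugin", my_param=42
)


# Decorator-style registration works because `register` returns the
# inner `do_register` closure when no callable is passed.
@TrainingTypePluginsRegistry.register("my_custom_v2", description="v2")
class MyCustomPluginV2(MyCustomPlugin):
    pass


plugin = TrainingTypePluginsRegistry.get("my_custom")  # MyCustomPlugin(my_param=42)
print(TrainingTypePluginsRegistry.available_plugins())
```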