
Commit 832a03a

kaushikb11, Sean Naren, and tchaton authored
Add Training Type Plugins Registry (#6982)
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
1 parent 67d2160 commit 832a03a

File tree

6 files changed: +276 -2 lines changed

pytorch_lightning/plugins/__init__.py

+11
@@ -1,4 +1,8 @@
 from pytorch_lightning.plugins.base_plugin import Plugin  # noqa: F401
+from pytorch_lightning.plugins.plugins_registry import (  # noqa: F401
+    call_training_type_register_plugins,
+    TrainingTypePluginsRegistry,
+)
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin  # noqa: F401
@@ -47,3 +51,10 @@
     'DDPShardedPlugin',
     'DDPSpawnShardedPlugin',
 ]
+
+from pathlib import Path
+
+FILE_ROOT = Path(__file__).parent
+TRAINING_TYPE_BASE_MODULE = "pytorch_lightning.plugins.training_type"
+
+call_training_type_register_plugins(FILE_ROOT, TRAINING_TYPE_BASE_MODULE)
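
With this hook-up, the registry is filled as a side effect of importing the plugins package. A minimal sketch of inspecting it, assuming only the names added in this commit (TrainingTypePluginsRegistry and its available_plugins method) plus plain Python:

    from pytorch_lightning.plugins import TrainingTypePluginsRegistry

    # The import above already ran call_training_type_register_plugins(), so every
    # built-in training type plugin that overrides `register_plugins` is listed here.
    print(TrainingTypePluginsRegistry.available_plugins())
    # e.g. ['deepspeed', 'deepspeed_stage_2', 'deepspeed_stage_2_offload', ...]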

pytorch_lightning/plugins/plugins_registry.py

+146
@@ -0,0 +1,146 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from collections import UserDict
from inspect import getmembers, isclass
from pathlib import Path
from typing import Any, Callable, List, Optional

from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class _TrainingTypePluginsRegistry(UserDict):
    """
    This class is a Registry that stores information about the Training Type Plugins.

    The Plugins are mapped to strings. These strings are names that identify
    a plugin, e.g., "deepspeed". It also returns an optional description and
    parameters to initialize the Plugin, which were defined during the
    registration.

    The motivation for having a TrainingTypePluginRegistry is to make it convenient
    for the Users to try different Plugins by passing just strings
    to the plugins flag to the Trainer.

    Example::

        @TrainingTypePluginsRegistry.register("lightning", description="Super fast", a=1, b=True)
        class LightningPlugin:
            def __init__(self, a, b):
                ...

    or

        TrainingTypePluginsRegistry.register("lightning", LightningPlugin, description="Super fast", a=1, b=True)

    """

    def register(
        self,
        name: str,
        plugin: Optional[Callable] = None,
        description: Optional[str] = None,
        override: bool = False,
        **init_params: Any,
    ) -> Callable:
        """
        Registers a plugin mapped to a name and with required metadata.

        Args:
            name: the name that identifies a plugin, e.g. "deepspeed_stage_3"
            plugin: plugin class
            description: plugin description
            override: overrides the registered plugin, if True
            init_params: parameters to initialize the plugin
        """
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'`name` must be a str, found {name}')

        if name in self and not override:
            raise MisconfigurationException(
                f"'{name}' is already present in the registry."
                " HINT: Use `override=True`."
            )

        data = {}
        data["description"] = description if description is not None else ""

        data["init_params"] = init_params

        def do_register(plugin: Callable) -> Callable:
            data["plugin"] = plugin
            self[name] = data
            return plugin

        if plugin is not None:
            return do_register(plugin)

        return do_register

    def get(self, name: str) -> Any:
        """
        Calls the registered plugin with the required parameters
        and returns the plugin object.

        Args:
            name (str): the name that identifies a plugin, e.g. "deepspeed_stage_3"
        """
        if name in self:
            data = self[name]
            return data["plugin"](**data["init_params"])

        err_msg = "'{}' not found in registry. Available names: {}"
        available_names = ", ".join(sorted(self.keys())) or "none"
        raise KeyError(err_msg.format(name, available_names))

    def remove(self, name: str) -> None:
        """Removes the registered plugin by name"""
        self.pop(name)

    def available_plugins(self) -> List:
        """Returns a list of registered plugins"""
        return list(self.keys())

    def __str__(self) -> str:
        return "Registered Plugins: {}".format(", ".join(self.keys()))


TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry()


def is_register_plugins_overridden(plugin: Callable) -> bool:
    method_name = "register_plugins"
    plugin_attr = getattr(plugin, method_name)
    super_attr = getattr(TrainingTypePlugin, method_name)

    if hasattr(plugin_attr, 'patch_loader_code'):
        is_overridden = plugin_attr.patch_loader_code != str(super_attr.__code__)
    else:
        is_overridden = plugin_attr.__code__ is not super_attr.__code__
    return is_overridden


def call_training_type_register_plugins(root: Path, base_module: str) -> None:
    # Ref: https://github.com/facebookresearch/ClassyVision/blob/master/classy_vision/generic/registry_utils.py#L14
    directory = "training_type"
    for file in os.listdir(root / directory):
        if file.endswith(".py") and not file.startswith("_"):
            module = file[:file.find(".py")]
            module = importlib.import_module(".".join([base_module, module]))
            for _, mod in getmembers(module, isclass):
                if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overridden(mod):
                    mod.register_plugins(TrainingTypePluginsRegistry)
                    break
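
Taken together, the registry and the discovery helper give a register-by-decorator or register-by-call API, as in the class docstring above. A short sketch of both halves with a hypothetical MyCustomPlugin; only the registry methods themselves (register, get, remove) come from this commit:

    from pytorch_lightning.plugins import TrainingTypePluginsRegistry


    @TrainingTypePluginsRegistry.register("my_plugin", description="Toy example", some_flag=True)
    class MyCustomPlugin:

        def __init__(self, some_flag: bool):
            self.some_flag = some_flag


    # `get` instantiates the stored class with the stored init_params.
    plugin = TrainingTypePluginsRegistry.get("my_plugin")
    assert isinstance(plugin, MyCustomPlugin) and plugin.some_flag is True

    # `remove` drops the entry again, e.g. to keep test runs isolated.
    TrainingTypePluginsRegistry.remove("my_plugin")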

pytorch_lightning/plugins/training_type/deepspeed.py

+20
@@ -542,3 +542,23 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) ->
         if total_batch_idx % self._original_accumulate_grad_batches == 0:
             current_global_step += 1
         return current_global_step
+
+    @classmethod
+    def register_plugins(cls, plugin_registry):
+        plugin_registry.register("deepspeed", cls, description="Default DeepSpeed Plugin")
+        plugin_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2)
+        plugin_registry.register(
+            "deepspeed_stage_2_offload",
+            cls,
+            description="DeepSpeed ZeRO Stage 2 and CPU Offload",
+            stage=2,
+            cpu_offload=True
+        )
+        plugin_registry.register("deepspeed_stage_3", cls, description="DeepSpeed ZeRO Stage 3", stage=3)
+        plugin_registry.register(
+            "deepspeed_stage_3_offload",
+            cls,
+            description="DeepSpeed ZeRO Stage 3 and CPU Offload",
+            stage=3,
+            cpu_offload=True
+        )
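
These registered names are what the new string shorthand resolves to. A minimal sketch of the intended usage, assuming a machine with a GPU and DeepSpeed installed (it mirrors the Trainer test added at the bottom of this commit):

    from pytorch_lightning import Trainer

    # The string is looked up in TrainingTypePluginsRegistry and expands to
    # DeepSpeedPlugin(stage=3, cpu_offload=True) under the hood.
    trainer = Trainer(
        plugins="deepspeed_stage_3_offload",
        precision=16,
        gpus=1,
    )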

pytorch_lightning/plugins/training_type/training_type_plugin.py

+4
@@ -286,3 +286,7 @@ def call_configure_sharded_model_hook(self) -> bool:
     @call_configure_sharded_model_hook.setter
     def call_configure_sharded_model_hook(self, mode: bool) -> None:
         self._call_configure_sharded_model_hook = mode
+
+    @classmethod
+    def register_plugins(cls, plugin_registry):
+        pass
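
This no-op base hook is what the discovery helper looks for: only subclasses that override register_plugins are registered automatically, and only modules inside pytorch_lightning/plugins/training_type are scanned. An illustrative sketch with a hypothetical out-of-tree plugin, which therefore has to invoke the hook itself:

    from pytorch_lightning.plugins import TrainingTypePluginsRegistry
    from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin


    class MyShardedPlugin(TrainingTypePlugin):
        """Hypothetical plugin; the abstract TrainingTypePlugin methods are omitted for brevity."""

        @classmethod
        def register_plugins(cls, plugin_registry):
            plugin_registry.register("my_sharded", cls, description="Hypothetical example plugin")


    # Not in the scanned directory, so register explicitly instead of relying on import-time discovery.
    MyShardedPlugin.register_plugins(TrainingTypePluginsRegistry)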

pytorch_lightning/trainer/connectors/accelerator_connector.py

+12 -2
@@ -42,6 +42,7 @@
     TPUHalfPrecisionPlugin,
     TPUSpawnPlugin,
     TrainingTypePlugin,
+    TrainingTypePluginsRegistry,
 )
 from pytorch_lightning.plugins.environments import (
     ClusterEnvironment,
@@ -163,7 +164,16 @@ def handle_given_plugins(
         cluster_environment = None

         for plug in plugins:
-            if isinstance(plug, str):
+            if isinstance(plug, str) and plug in TrainingTypePluginsRegistry:
+                if training_type is None:
+                    training_type = TrainingTypePluginsRegistry.get(plug)
+                else:
+                    raise MisconfigurationException(
+                        'You can only specify one precision and one training type plugin.'
+                        ' Found more than 1 training type plugin:'
+                        f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}'
+                    )
+            elif isinstance(plug, str):
                 # Reset the distributed type as the user has overridden training type
                 # via the plugins argument
                 self._distrib_type = None
@@ -530,7 +540,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
             rank_zero_warn(
                 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
             )
-        # todo: in some cases it yield in comarison None and int
+        # todo: in some cases it yield in comparison None and int
         if (self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1):
             self._distrib_type = DistributedType.DDP
         else:
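
The new branch in handle_given_plugins is what resolves registry strings, and it rejects conflicting requests. A sketch of that error path, assuming DeepSpeed is installed so both names resolve; the message text is the one added above:

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    try:
        # Two registry names that both map to training type plugins are rejected.
        Trainer(plugins=["deepspeed_stage_2", "deepspeed_stage_3"])
    except MisconfigurationException as err:
        print(err)  # "You can only specify one precision and one training type plugin. ..."
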
+83
@@ -0,0 +1,83 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import TrainingTypePluginsRegistry
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin
from tests.helpers.runif import RunIf


def test_training_type_plugins_registry_with_new_plugin():

    class TestPlugin:

        def __init__(self, param1, param2):
            self.param1 = param1
            self.param2 = param2

    plugin_name = "test_plugin"
    plugin_description = "Test Plugin"

    TrainingTypePluginsRegistry.register(
        plugin_name, TestPlugin, description=plugin_description, param1="abc", param2=123
    )

    assert plugin_name in TrainingTypePluginsRegistry
    assert TrainingTypePluginsRegistry[plugin_name]["description"] == plugin_description
    assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == {"param1": "abc", "param2": 123}
    assert isinstance(TrainingTypePluginsRegistry.get(plugin_name), TestPlugin)

    TrainingTypePluginsRegistry.remove(plugin_name)
    assert plugin_name not in TrainingTypePluginsRegistry


@pytest.mark.parametrize(
    "plugin_name, init_params",
    [
        ("deepspeed", {}),
        ("deepspeed_stage_2", {"stage": 2}),
        ("deepspeed_stage_2_offload", {"stage": 2, "cpu_offload": True}),
        ("deepspeed_stage_3", {"stage": 3}),
        ("deepspeed_stage_3_offload", {"stage": 3, "cpu_offload": True}),
    ],
)
def test_training_type_plugins_registry_with_deepspeed_plugins(plugin_name, init_params):

    assert plugin_name in TrainingTypePluginsRegistry
    assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == init_params
    assert TrainingTypePluginsRegistry[plugin_name]["plugin"] == DeepSpeedPlugin


@RunIf(deepspeed=True)
@pytest.mark.parametrize("plugin", ["deepspeed", "deepspeed_stage_2_offload", "deepspeed_stage_3"])
def test_training_type_plugins_registry_with_trainer(tmpdir, plugin):

    trainer = Trainer(
        default_root_dir=tmpdir,
        plugins=plugin,
        precision=16,
    )

    assert isinstance(trainer.training_type_plugin, DeepSpeedPlugin)
